mkshing
commited on
Commit
·
ceadefe
0
Parent(s):
initial commit
Browse files- .gitattributes +35 -0
- LICENSE +71 -0
- README.md +97 -0
- config.json +31 -0
- configuration_evomistral.py +9 -0
- generation_config.json +7 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +408 -0
- modeling_evomistral.py +1379 -0
- special_tokens_map.json +23 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +42 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MICROSOFT RESEARCH LICENSE TERMS
|
2 |
+
|
3 |
+
IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.
|
4 |
+
|
5 |
+
These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS.
|
6 |
+
|
7 |
+
1) INSTALLATION AND USE RIGHTS TO THE MATERIALS.
|
8 |
+
|
9 |
+
Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes:
|
10 |
+
|
11 |
+
a) Source Code. If source code is included, you may use and modify the source code, but you may not distribute the source code.
|
12 |
+
|
13 |
+
b) Object Code. If object code is included, you may use the object code, but you may not distribute the object code.
|
14 |
+
|
15 |
+
c) Models. If machine learning model(s) are included, you may use the model(s), but you may not distribute the models.
|
16 |
+
|
17 |
+
d) Data. If data is included, you may use and modify the data, but your use and modification must be consistent with the consent under which the data was provided and/or gathered and you may not distribute the data or your modifications to the data.
|
18 |
+
|
19 |
+
2) SCOPE OF LICENSE. The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to):
|
20 |
+
|
21 |
+
a) work around any technical limitations in the Materials that only allow you to use it in certain ways;
|
22 |
+
|
23 |
+
b) reverse engineer, decompile or disassemble the Materials;
|
24 |
+
|
25 |
+
c) remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials;
|
26 |
+
|
27 |
+
d) use the Materials in any way that is against the law or to create or propagate malware; or
|
28 |
+
|
29 |
+
e) share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party.
|
30 |
+
|
31 |
+
3) PERSONAL DATA. If the data (set forth in Section 1(c) above) includes or is found to include any data that enables any ability to identify an individual (“Personal Data”), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, immediately upon the completion of your research.
|
32 |
+
|
33 |
+
4) LICENSE TO MICROSOFT. Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose.
|
34 |
+
|
35 |
+
5) PUBLICATION. You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation.
|
36 |
+
|
37 |
+
6) FEEDBACK. Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the
|
38 |
+
|
39 |
+
feedback is designated by you as confidential. Such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above.
|
40 |
+
|
41 |
+
7) EXPORT RESTRICTIONS. You must comply with all domestic and international export laws and regulations that apply to the Materials, which include restrictions on destinations, end users, and end use. For further information on export restrictions, visit (aka.ms/exporting).
|
42 |
+
|
43 |
+
8) SUPPORT SERVICES. Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
44 |
+
|
45 |
+
9) BINDING ARBITRATION AND CLASS ACTION WAIVER. This Section applies if you live in (or, if a business, your principal place of business is in) the United States. If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to binding individual arbitration before the American Arbitration Association under the Federal Arbitration Act (“FAA”), and not to sue in court in front of a judge or jury. Instead, a neutral arbitrator will decide. Class action lawsuits, class-wide arbitrations, private attorney-general actions, and any other proceeding where someone acts in a representative capacity are not allowed; nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms.
|
46 |
+
|
47 |
+
10) ENTIRE AGREEMENT. This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials.
|
48 |
+
|
49 |
+
11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES. If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration).
|
50 |
+
|
51 |
+
12) CONSUMER RIGHTS; REGIONAL VARIATIONS. This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you:
|
52 |
+
|
53 |
+
a) Australia. You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights.
|
54 |
+
|
55 |
+
b) Canada. If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software.
|
56 |
+
|
57 |
+
c) Germany and Austria.
|
58 |
+
|
59 |
+
i. Warranty. The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software.
|
60 |
+
|
61 |
+
ii. Limitation of Liability. In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law.
|
62 |
+
|
63 |
+
Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence.
|
64 |
+
|
65 |
+
13) DISCLAIMER OF WARRANTY. THE MATERIALS ARE LICENSED “AS IS.” YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.
|
66 |
+
|
67 |
+
14) LIMITATION ON AND EXCLUSION OF DAMAGES. IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES.
|
68 |
+
|
69 |
+
This limitation applies to (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law.
|
70 |
+
|
71 |
+
It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages.
|
README.md
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- ja
|
4 |
+
license: other
|
5 |
+
library_name: transformers
|
6 |
+
---
|
7 |
+
|
8 |
+
# 🐟 EvoLLM-JP-v1-10B
|
9 |
+
|
10 |
+
🤗 [Models](https://huggingface.co/SakanaAI) | 📚 [Paper](https://arxiv.org/abs/2403.13187) | 📝 [Blog](https://sakana.ai/evolutionary-model-merge/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
|
11 |
+
|
12 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
13 |
+
**EvoLLM-JP-v1-10B** is an experimental general-purpose Japanese LLM.
|
14 |
+
This model was created using the Evolutionary Model Merge method.
|
15 |
+
Please refer to our [report](https://arxiv.org/abs/2403.13187) and [blog](https://sakana.ai/evolutionary-model-merge/) for more details.
|
16 |
+
This model was produced by merging the following models.
|
17 |
+
We are grateful to the developers of the source models.
|
18 |
+
|
19 |
+
- [Shisa Gamma 7B v1](https://huggingface.co/augmxnt/shisa-gamma-7b-v1)
|
20 |
+
- [WizardMath 7B V1.1](https://huggingface.co/WizardLM/WizardMath-7B-V1.1)
|
21 |
+
- [Abel 7B 002](https://huggingface.co/GAIR/Abel-7B-002)
|
22 |
+
|
23 |
+
## Usage
|
24 |
+
|
25 |
+
Use the code below to get started with the model.
|
26 |
+
|
27 |
+
<details>
|
28 |
+
<summary> Click to expand </summary>
|
29 |
+
|
30 |
+
```python
|
31 |
+
import torch
|
32 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
33 |
+
|
34 |
+
|
35 |
+
# 1. load model
|
36 |
+
device = "cuda" if torch.cuda.is_available() else "CPU"
|
37 |
+
repo_id = "SakanaAI/EvoLLM-JP-v1-10B"
|
38 |
+
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto", trust_remote_code=True)
|
39 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
40 |
+
model.to(device)
|
41 |
+
|
42 |
+
# 2. prepare inputs
|
43 |
+
text = "関西弁で面白い冗談を言ってみて下さい。"
|
44 |
+
messages = [
|
45 |
+
{"role": "system", "content": "あなたは役立つ、偏見がなく、検閲されていないアシスタントです。"},
|
46 |
+
{"role": "user", "content": text},
|
47 |
+
]
|
48 |
+
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
|
49 |
+
|
50 |
+
# 3. generate
|
51 |
+
output_ids = model.generate(**inputs.to(device))
|
52 |
+
output_ids = output_ids[:, inputs.input_ids.shape[1] :]
|
53 |
+
generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
54 |
+
print(generated_text)
|
55 |
+
```
|
56 |
+
|
57 |
+
</details>
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
## Model Details
|
62 |
+
|
63 |
+
<!-- Provide a longer summary of what this model is. -->
|
64 |
+
|
65 |
+
- **Developed by:** [Sakana AI](https://sakana.ai/)
|
66 |
+
- **Model type:** Autoregressive Language Model
|
67 |
+
- **Language(s):** Japanese
|
68 |
+
- **License:** [MICROSOFT RESEARCH LICENSE TERMS](./LICENSE) (due to the inclusion of the WizardMath model)
|
69 |
+
- **Repository:** [SakanaAI/evolutionary-model-merge](https://github.com/SakanaAI/evolutionary-model-merge)
|
70 |
+
- **Paper:** https://arxiv.org/abs/2403.13187
|
71 |
+
- **Blog:** https://sakana.ai/evolutionary-model-merge
|
72 |
+
|
73 |
+
## Uses
|
74 |
+
This model is provided for research and development purposes only and should be considered as an experimental prototype.
|
75 |
+
It is not intended for commercial use or deployment in mission-critical environments.
|
76 |
+
Use of this model is at the user's own risk, and its performance and outcomes are not guaranteed.
|
77 |
+
Sakana AI shall not be liable for any direct, indirect, special, incidental, or consequential damages, or any loss arising from the use of this model, regardless of the results obtained.
|
78 |
+
Users must fully understand the risks associated with the use of this model and use it at their own discretion.
|
79 |
+
|
80 |
+
|
81 |
+
## Acknowledgement
|
82 |
+
|
83 |
+
We would like to thank the developers of the source models for their contributions and for making their work available.
|
84 |
+
|
85 |
+
|
86 |
+
## Citation
|
87 |
+
|
88 |
+
```bibtex
|
89 |
+
@misc{akiba2024evomodelmerge,
|
90 |
+
title = {Evolutionary Optimization of Model Merging Recipes},
|
91 |
+
author. = {Takuya Akiba and Makoto Shing and Yujin Tang and Qi Sun and David Ha},
|
92 |
+
year = {2024},
|
93 |
+
eprint = {2403.13187},
|
94 |
+
archivePrefix = {arXiv},
|
95 |
+
primaryClass = {cs.NE}
|
96 |
+
}
|
97 |
+
```
|
config.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "SakanaAI/EvoLLM-v1-JP-10B",
|
3 |
+
"architectures": [
|
4 |
+
"EvoMistralForCausalLM"
|
5 |
+
],
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "SakanaAI/EvoLLM-v1-JP-10B--configuration_evomistral.EvoMistralConfig",
|
9 |
+
"AutoModelForCausalLM": "SakanaAI/EvoLLM-v1-JP-10B--modeling_evomistral.EvoMistralForCausalLM"
|
10 |
+
},
|
11 |
+
"bos_token_id": 1,
|
12 |
+
"eos_token_id": 2,
|
13 |
+
"hidden_act": "silu",
|
14 |
+
"hidden_size": 4096,
|
15 |
+
"initializer_range": 0.02,
|
16 |
+
"intermediate_size": 14336,
|
17 |
+
"max_position_embeddings": 32768,
|
18 |
+
"model_type": "evomistral",
|
19 |
+
"num_attention_heads": 32,
|
20 |
+
"num_hidden_layers": 44,
|
21 |
+
"num_hops": 65,
|
22 |
+
"num_key_value_heads": 8,
|
23 |
+
"rms_norm_eps": 1e-05,
|
24 |
+
"rope_theta": 10000.0,
|
25 |
+
"sliding_window": 4096,
|
26 |
+
"tie_word_embeddings": false,
|
27 |
+
"torch_dtype": "float32",
|
28 |
+
"transformers_version": "4.38.2",
|
29 |
+
"use_cache": false,
|
30 |
+
"vocab_size": 32000
|
31 |
+
}
|
configuration_evomistral.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.models.mistral.configuration_mistral import MistralConfig
|
2 |
+
|
3 |
+
|
4 |
+
class EvoMistralConfig(MistralConfig):
|
5 |
+
model_type = "evomistral"
|
6 |
+
|
7 |
+
def __init__(self, num_hops: int = 64, **kwargs):
|
8 |
+
self.num_hops = num_hops
|
9 |
+
super().__init__(**kwargs)
|
generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token_id": 1,
|
3 |
+
"eos_token_id": 2,
|
4 |
+
"max_new_tokens": 1024,
|
5 |
+
"pad_token_id": 2,
|
6 |
+
"transformers_version": "4.38.2"
|
7 |
+
}
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0eb2cb18f8988039b650ffb45521542e37370d99db71825a3d48b289376e26ef
|
3 |
+
size 4943163008
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df681d4220f14d4a77a6af7b069e59eadee4b48703695f016f8f7bf63b9a9442
|
3 |
+
size 4999819336
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8f1f60e6243cd8dfa1d43777c2e0c56dd7bbe2e2078e40d29121b98d29de53d
|
3 |
+
size 4915916184
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0be58a55d55d6560d2f924ae563d3e1cc32a83c0295f3be82230040bfd88aa3f
|
3 |
+
size 4859300760
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 19718152712
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"lm_head.weight": "model-00004-of-00004.safetensors",
|
7 |
+
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
8 |
+
"model.input_layers": "model-00001-of-00004.safetensors",
|
9 |
+
"model.input_scales": "model-00001-of-00004.safetensors",
|
10 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
11 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
12 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
13 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
14 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
15 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
16 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
17 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
18 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
19 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
20 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
21 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
22 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
23 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
24 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
25 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
26 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
27 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
28 |
+
"model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
29 |
+
"model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
30 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
31 |
+
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
32 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
33 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
34 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
35 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
36 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
37 |
+
"model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
38 |
+
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
39 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
40 |
+
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
41 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
42 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
43 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
44 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
45 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
46 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
47 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
48 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
49 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
50 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
51 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
52 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
53 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
54 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
55 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
56 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
57 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
58 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
59 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
60 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
61 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
62 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
63 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
64 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
65 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
66 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
67 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
68 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
69 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
70 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
71 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
72 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
73 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
74 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
75 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
76 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
77 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
78 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
79 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
80 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
81 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
82 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
83 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
84 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
85 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
86 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
87 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
88 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
89 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
90 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
91 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
92 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
93 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
94 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
95 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
96 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
97 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
98 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
99 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
100 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
101 |
+
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
102 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
103 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
104 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
105 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
106 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
107 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
108 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
109 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
110 |
+
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
111 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
112 |
+
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
113 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
114 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
115 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
116 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
117 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
118 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
119 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
120 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
121 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
122 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
123 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
124 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
125 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
126 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
127 |
+
"model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
128 |
+
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
129 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
130 |
+
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
131 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
132 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
133 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
134 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
135 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
136 |
+
"model.layers.21.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
137 |
+
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
138 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
139 |
+
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
140 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
141 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
142 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
143 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
144 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
145 |
+
"model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
146 |
+
"model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
147 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
148 |
+
"model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
149 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
150 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
151 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
152 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
153 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
154 |
+
"model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
155 |
+
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
156 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
157 |
+
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
158 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
159 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
160 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
161 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
162 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
163 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
164 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
165 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
166 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
167 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
168 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
169 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
170 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
171 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
172 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
173 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
174 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
175 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
176 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
177 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
178 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
179 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
180 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
181 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
182 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
183 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
184 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
185 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
186 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
187 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
188 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
189 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
190 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
191 |
+
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
192 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
193 |
+
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
194 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
195 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
196 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
197 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
198 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
199 |
+
"model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
200 |
+
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
201 |
+
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
202 |
+
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
203 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
204 |
+
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
205 |
+
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
206 |
+
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
207 |
+
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
208 |
+
"model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
209 |
+
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
210 |
+
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
211 |
+
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
212 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
213 |
+
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
214 |
+
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
215 |
+
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
216 |
+
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
217 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
218 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
219 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
220 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
221 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
222 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
223 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
224 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
225 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
226 |
+
"model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
227 |
+
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
228 |
+
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
229 |
+
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
230 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
231 |
+
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
232 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
233 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
234 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
235 |
+
"model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
236 |
+
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
237 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
238 |
+
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
239 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
240 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
241 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
242 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
243 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
244 |
+
"model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
245 |
+
"model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
246 |
+
"model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
247 |
+
"model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
248 |
+
"model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
249 |
+
"model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
250 |
+
"model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
251 |
+
"model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
252 |
+
"model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
253 |
+
"model.layers.33.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
254 |
+
"model.layers.33.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
255 |
+
"model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
256 |
+
"model.layers.33.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
257 |
+
"model.layers.33.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
258 |
+
"model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
259 |
+
"model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
260 |
+
"model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
261 |
+
"model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
262 |
+
"model.layers.34.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
263 |
+
"model.layers.34.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
264 |
+
"model.layers.34.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
265 |
+
"model.layers.34.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
266 |
+
"model.layers.34.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
267 |
+
"model.layers.34.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
268 |
+
"model.layers.34.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
269 |
+
"model.layers.34.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
270 |
+
"model.layers.34.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
271 |
+
"model.layers.35.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
272 |
+
"model.layers.35.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
273 |
+
"model.layers.35.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
274 |
+
"model.layers.35.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
275 |
+
"model.layers.35.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
276 |
+
"model.layers.35.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
277 |
+
"model.layers.35.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
278 |
+
"model.layers.35.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
279 |
+
"model.layers.35.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
280 |
+
"model.layers.36.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
281 |
+
"model.layers.36.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
282 |
+
"model.layers.36.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
283 |
+
"model.layers.36.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
284 |
+
"model.layers.36.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
285 |
+
"model.layers.36.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
286 |
+
"model.layers.36.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
287 |
+
"model.layers.36.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
288 |
+
"model.layers.36.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
289 |
+
"model.layers.37.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
290 |
+
"model.layers.37.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
291 |
+
"model.layers.37.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
292 |
+
"model.layers.37.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
293 |
+
"model.layers.37.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
294 |
+
"model.layers.37.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
295 |
+
"model.layers.37.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
296 |
+
"model.layers.37.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
297 |
+
"model.layers.37.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
298 |
+
"model.layers.38.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
299 |
+
"model.layers.38.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
300 |
+
"model.layers.38.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
301 |
+
"model.layers.38.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
302 |
+
"model.layers.38.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
303 |
+
"model.layers.38.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
304 |
+
"model.layers.38.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
305 |
+
"model.layers.38.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
306 |
+
"model.layers.38.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
307 |
+
"model.layers.39.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
308 |
+
"model.layers.39.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
309 |
+
"model.layers.39.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
310 |
+
"model.layers.39.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
311 |
+
"model.layers.39.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
312 |
+
"model.layers.39.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
313 |
+
"model.layers.39.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
314 |
+
"model.layers.39.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
315 |
+
"model.layers.39.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
316 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
317 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
318 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
319 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
320 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
321 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
322 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
323 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
324 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
325 |
+
"model.layers.40.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
326 |
+
"model.layers.40.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
327 |
+
"model.layers.40.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
328 |
+
"model.layers.40.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
329 |
+
"model.layers.40.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
330 |
+
"model.layers.40.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
331 |
+
"model.layers.40.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
332 |
+
"model.layers.40.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
333 |
+
"model.layers.40.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
334 |
+
"model.layers.41.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
335 |
+
"model.layers.41.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
336 |
+
"model.layers.41.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
337 |
+
"model.layers.41.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
338 |
+
"model.layers.41.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
339 |
+
"model.layers.41.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
340 |
+
"model.layers.41.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
341 |
+
"model.layers.41.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
342 |
+
"model.layers.41.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
343 |
+
"model.layers.42.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
344 |
+
"model.layers.42.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
345 |
+
"model.layers.42.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
346 |
+
"model.layers.42.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
347 |
+
"model.layers.42.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
348 |
+
"model.layers.42.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
349 |
+
"model.layers.42.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
350 |
+
"model.layers.42.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
351 |
+
"model.layers.42.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
352 |
+
"model.layers.43.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
353 |
+
"model.layers.43.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
354 |
+
"model.layers.43.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
355 |
+
"model.layers.43.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
356 |
+
"model.layers.43.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
357 |
+
"model.layers.43.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
358 |
+
"model.layers.43.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
359 |
+
"model.layers.43.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
360 |
+
"model.layers.43.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
361 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
362 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
363 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
364 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
365 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
366 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
367 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
368 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
369 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
370 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
371 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
372 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
373 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
374 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
375 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
376 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
377 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
378 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
379 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
380 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
381 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
382 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
383 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
384 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
385 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
386 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
387 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
388 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
389 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
390 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
391 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
392 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
393 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
394 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
395 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
396 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
397 |
+
"model.layers.9.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
398 |
+
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
399 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
400 |
+
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
401 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
402 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
403 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
404 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
405 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
406 |
+
"model.norm.weight": "model-00004-of-00004.safetensors"
|
407 |
+
}
|
408 |
+
}
|
modeling_evomistral.py
ADDED
@@ -0,0 +1,1379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" PyTorch Mistral model."""
|
21 |
+
import inspect
|
22 |
+
import math
|
23 |
+
import warnings
|
24 |
+
from typing import List, Optional, Tuple, Union
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn.functional as F
|
28 |
+
import torch.utils.checkpoint
|
29 |
+
from torch import nn
|
30 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
31 |
+
|
32 |
+
from transformers.activations import ACT2FN
|
33 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
34 |
+
from transformers.modeling_outputs import (
|
35 |
+
BaseModelOutputWithPast,
|
36 |
+
CausalLMOutputWithPast,
|
37 |
+
SequenceClassifierOutputWithPast,
|
38 |
+
)
|
39 |
+
from transformers.modeling_utils import PreTrainedModel
|
40 |
+
from transformers.utils import (
|
41 |
+
add_start_docstrings,
|
42 |
+
add_start_docstrings_to_model_forward,
|
43 |
+
is_flash_attn_2_available,
|
44 |
+
logging,
|
45 |
+
replace_return_docstrings,
|
46 |
+
)
|
47 |
+
from .configuration_evomistral import EvoMistralConfig
|
48 |
+
|
49 |
+
if is_flash_attn_2_available():
|
50 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
51 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
52 |
+
|
53 |
+
_flash_supports_window_size = "window_size" in list(
|
54 |
+
inspect.signature(flash_attn_func).parameters
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
logger = logging.get_logger(__name__)
|
59 |
+
|
60 |
+
_CONFIG_FOR_DOC = "MistralConfig"
|
61 |
+
|
62 |
+
|
63 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
64 |
+
def _get_unpad_data(attention_mask):
|
65 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
66 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
67 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
68 |
+
cu_seqlens = F.pad(
|
69 |
+
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
70 |
+
)
|
71 |
+
return (
|
72 |
+
indices,
|
73 |
+
cu_seqlens,
|
74 |
+
max_seqlen_in_batch,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
|
79 |
+
class MistralRMSNorm(nn.Module):
|
80 |
+
def __init__(self, hidden_size, eps=1e-6):
|
81 |
+
"""
|
82 |
+
MistralRMSNorm is equivalent to T5LayerNorm
|
83 |
+
"""
|
84 |
+
super().__init__()
|
85 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
86 |
+
self.variance_epsilon = eps
|
87 |
+
|
88 |
+
def forward(self, hidden_states, residual=None):
|
89 |
+
input_dtype = hidden_states.dtype
|
90 |
+
hidden_states = hidden_states.to(torch.float32)
|
91 |
+
if residual is not None:
|
92 |
+
hidden_states = hidden_states + residual.to(torch.float32)
|
93 |
+
residual = hidden_states.to(input_dtype)
|
94 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
95 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
96 |
+
hidden_states = self.weight * hidden_states.to(input_dtype)
|
97 |
+
if residual is None:
|
98 |
+
return hidden_states
|
99 |
+
else:
|
100 |
+
return hidden_states, residual
|
101 |
+
|
102 |
+
|
103 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
|
104 |
+
class MistralRotaryEmbedding(nn.Module):
|
105 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
self.dim = dim
|
109 |
+
self.max_position_embeddings = max_position_embeddings
|
110 |
+
self.base = base
|
111 |
+
inv_freq = 1.0 / (
|
112 |
+
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
113 |
+
)
|
114 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
115 |
+
|
116 |
+
# Build here to make `torch.jit.trace` work.
|
117 |
+
self._set_cos_sin_cache(
|
118 |
+
seq_len=max_position_embeddings,
|
119 |
+
device=self.inv_freq.device,
|
120 |
+
dtype=torch.get_default_dtype(),
|
121 |
+
)
|
122 |
+
|
123 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
124 |
+
self.max_seq_len_cached = seq_len
|
125 |
+
t = torch.arange(
|
126 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
127 |
+
)
|
128 |
+
|
129 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
130 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
131 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
132 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
133 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
134 |
+
|
135 |
+
def forward(self, x, seq_len=None):
|
136 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
137 |
+
if seq_len > self.max_seq_len_cached:
|
138 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
139 |
+
|
140 |
+
return (
|
141 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
142 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
147 |
+
def rotate_half(x):
|
148 |
+
"""Rotates half the hidden dims of the input."""
|
149 |
+
x1 = x[..., : x.shape[-1] // 2]
|
150 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
151 |
+
return torch.cat((-x2, x1), dim=-1)
|
152 |
+
|
153 |
+
|
154 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
155 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
156 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
q (`torch.Tensor`): The query tensor.
|
160 |
+
k (`torch.Tensor`): The key tensor.
|
161 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
162 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
163 |
+
position_ids (`torch.Tensor`):
|
164 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
165 |
+
used to pass offsetted position ids when working with a KV-cache.
|
166 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
167 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
168 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
169 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
170 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
171 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
172 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
173 |
+
Returns:
|
174 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
175 |
+
"""
|
176 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
177 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
178 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
179 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
180 |
+
return q_embed, k_embed
|
181 |
+
|
182 |
+
|
183 |
+
class MistralMLP(nn.Module):
|
184 |
+
def __init__(self, config):
|
185 |
+
super().__init__()
|
186 |
+
self.config = config
|
187 |
+
self.hidden_size = config.hidden_size
|
188 |
+
self.intermediate_size = config.intermediate_size
|
189 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
190 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
191 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
192 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
193 |
+
|
194 |
+
def forward(self, x):
|
195 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
196 |
+
|
197 |
+
|
198 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
199 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
200 |
+
"""
|
201 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
202 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
203 |
+
"""
|
204 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
205 |
+
if n_rep == 1:
|
206 |
+
return hidden_states
|
207 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
208 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
209 |
+
)
|
210 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
211 |
+
|
212 |
+
|
213 |
+
class MistralAttention(nn.Module):
|
214 |
+
"""
|
215 |
+
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
216 |
+
and "Generating Long Sequences with Sparse Transformers".
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(self, config: EvoMistralConfig):
|
220 |
+
super().__init__()
|
221 |
+
self.config = config
|
222 |
+
self.hidden_size = config.hidden_size
|
223 |
+
self.num_heads = config.num_attention_heads
|
224 |
+
self.head_dim = self.hidden_size // self.num_heads
|
225 |
+
self.num_key_value_heads = config.num_key_value_heads
|
226 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
227 |
+
self.max_position_embeddings = config.max_position_embeddings
|
228 |
+
self.rope_theta = config.rope_theta
|
229 |
+
self.is_causal = True
|
230 |
+
|
231 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
232 |
+
raise ValueError(
|
233 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
234 |
+
f" and `num_heads`: {self.num_heads})."
|
235 |
+
)
|
236 |
+
self.q_proj = nn.Linear(
|
237 |
+
self.hidden_size, self.num_heads * self.head_dim, bias=False
|
238 |
+
)
|
239 |
+
self.k_proj = nn.Linear(
|
240 |
+
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
241 |
+
)
|
242 |
+
self.v_proj = nn.Linear(
|
243 |
+
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
244 |
+
)
|
245 |
+
self.o_proj = nn.Linear(
|
246 |
+
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
247 |
+
)
|
248 |
+
|
249 |
+
self.rotary_emb = MistralRotaryEmbedding(
|
250 |
+
self.head_dim,
|
251 |
+
max_position_embeddings=self.max_position_embeddings,
|
252 |
+
base=self.rope_theta,
|
253 |
+
)
|
254 |
+
|
255 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
256 |
+
return (
|
257 |
+
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
258 |
+
.transpose(1, 2)
|
259 |
+
.contiguous()
|
260 |
+
)
|
261 |
+
|
262 |
+
def forward(
|
263 |
+
self,
|
264 |
+
hidden_states: torch.Tensor,
|
265 |
+
attention_mask: Optional[torch.Tensor] = None,
|
266 |
+
position_ids: Optional[torch.LongTensor] = None,
|
267 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
268 |
+
output_attentions: bool = False,
|
269 |
+
use_cache: bool = False,
|
270 |
+
**kwargs,
|
271 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
272 |
+
if "padding_mask" in kwargs:
|
273 |
+
warnings.warn(
|
274 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
275 |
+
)
|
276 |
+
bsz, q_len, _ = hidden_states.size()
|
277 |
+
|
278 |
+
query_states = self.q_proj(hidden_states)
|
279 |
+
key_states = self.k_proj(hidden_states)
|
280 |
+
value_states = self.v_proj(hidden_states)
|
281 |
+
|
282 |
+
query_states = query_states.view(
|
283 |
+
bsz, q_len, self.num_heads, self.head_dim
|
284 |
+
).transpose(1, 2)
|
285 |
+
key_states = key_states.view(
|
286 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
287 |
+
).transpose(1, 2)
|
288 |
+
value_states = value_states.view(
|
289 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
290 |
+
).transpose(1, 2)
|
291 |
+
|
292 |
+
kv_seq_len = key_states.shape[-2]
|
293 |
+
if past_key_value is not None:
|
294 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
295 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
296 |
+
query_states, key_states = apply_rotary_pos_emb(
|
297 |
+
query_states, key_states, cos, sin, position_ids
|
298 |
+
)
|
299 |
+
|
300 |
+
if past_key_value is not None:
|
301 |
+
# reuse k, v, self_attention
|
302 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
303 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
304 |
+
|
305 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
306 |
+
|
307 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
308 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
309 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
310 |
+
|
311 |
+
attn_weights = torch.matmul(
|
312 |
+
query_states, key_states.transpose(2, 3)
|
313 |
+
) / math.sqrt(self.head_dim)
|
314 |
+
|
315 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
316 |
+
raise ValueError(
|
317 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
318 |
+
f" {attn_weights.size()}"
|
319 |
+
)
|
320 |
+
|
321 |
+
if attention_mask is not None:
|
322 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
323 |
+
raise ValueError(
|
324 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
325 |
+
)
|
326 |
+
|
327 |
+
attn_weights = attn_weights + attention_mask
|
328 |
+
|
329 |
+
# upcast attention to fp32
|
330 |
+
attn_weights = nn.functional.softmax(
|
331 |
+
attn_weights, dim=-1, dtype=torch.float32
|
332 |
+
).to(query_states.dtype)
|
333 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
334 |
+
|
335 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
336 |
+
raise ValueError(
|
337 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
338 |
+
f" {attn_output.size()}"
|
339 |
+
)
|
340 |
+
|
341 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
342 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
343 |
+
|
344 |
+
attn_output = self.o_proj(attn_output)
|
345 |
+
|
346 |
+
if not output_attentions:
|
347 |
+
attn_weights = None
|
348 |
+
|
349 |
+
return attn_output, attn_weights, past_key_value
|
350 |
+
|
351 |
+
|
352 |
+
class MistralFlashAttention2(MistralAttention):
|
353 |
+
"""
|
354 |
+
Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
|
355 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
356 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
357 |
+
"""
|
358 |
+
|
359 |
+
def forward(
|
360 |
+
self,
|
361 |
+
hidden_states: torch.Tensor,
|
362 |
+
attention_mask: Optional[torch.Tensor] = None,
|
363 |
+
position_ids: Optional[torch.LongTensor] = None,
|
364 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
365 |
+
output_attentions: bool = False,
|
366 |
+
use_cache: bool = False,
|
367 |
+
**kwargs,
|
368 |
+
):
|
369 |
+
if "padding_mask" in kwargs:
|
370 |
+
warnings.warn(
|
371 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
372 |
+
)
|
373 |
+
|
374 |
+
# overwrite attention_mask with padding_mask
|
375 |
+
attention_mask = kwargs.pop("padding_mask")
|
376 |
+
bsz, q_len, _ = hidden_states.size()
|
377 |
+
|
378 |
+
query_states = self.q_proj(hidden_states)
|
379 |
+
key_states = self.k_proj(hidden_states)
|
380 |
+
value_states = self.v_proj(hidden_states)
|
381 |
+
|
382 |
+
query_states = query_states.view(
|
383 |
+
bsz, q_len, self.num_heads, self.head_dim
|
384 |
+
).transpose(1, 2)
|
385 |
+
key_states = key_states.view(
|
386 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
387 |
+
).transpose(1, 2)
|
388 |
+
value_states = value_states.view(
|
389 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
390 |
+
).transpose(1, 2)
|
391 |
+
|
392 |
+
kv_seq_len = key_states.shape[-2]
|
393 |
+
if past_key_value is not None:
|
394 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
395 |
+
|
396 |
+
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
397 |
+
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
398 |
+
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
399 |
+
|
400 |
+
query_states, key_states = apply_rotary_pos_emb(
|
401 |
+
query_states, key_states, cos, sin, position_ids
|
402 |
+
)
|
403 |
+
|
404 |
+
use_sliding_windows = (
|
405 |
+
_flash_supports_window_size
|
406 |
+
and hasattr(self.config, "sliding_window") is not None
|
407 |
+
and kv_seq_len > self.config.sliding_window
|
408 |
+
)
|
409 |
+
|
410 |
+
if not _flash_supports_window_size:
|
411 |
+
logger.warning_once(
|
412 |
+
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
|
413 |
+
" make sure to upgrade flash-attn library."
|
414 |
+
)
|
415 |
+
|
416 |
+
if past_key_value is not None:
|
417 |
+
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
418 |
+
if (
|
419 |
+
hasattr(self.config, "sliding_window")
|
420 |
+
and kv_seq_len > self.config.sliding_window
|
421 |
+
):
|
422 |
+
slicing_tokens = kv_seq_len - self.config.sliding_window
|
423 |
+
|
424 |
+
past_key = past_key_value[0]
|
425 |
+
past_value = past_key_value[1]
|
426 |
+
|
427 |
+
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
428 |
+
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
429 |
+
|
430 |
+
if past_key.shape[-2] != self.config.sliding_window - 1:
|
431 |
+
raise ValueError(
|
432 |
+
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
433 |
+
f" {past_key.shape}"
|
434 |
+
)
|
435 |
+
|
436 |
+
past_key_value = (past_key, past_value)
|
437 |
+
|
438 |
+
if attention_mask is not None:
|
439 |
+
attention_mask = attention_mask[:, slicing_tokens:]
|
440 |
+
attention_mask = torch.cat(
|
441 |
+
[attention_mask, torch.ones_like(attention_mask[:, -1:])],
|
442 |
+
dim=-1,
|
443 |
+
)
|
444 |
+
|
445 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
446 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
447 |
+
|
448 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
449 |
+
|
450 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
451 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
452 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
453 |
+
|
454 |
+
# TODO: Mistral does not have dropout in the config??
|
455 |
+
# It is recommended to use dropout with FA according to the docs
|
456 |
+
# when training.
|
457 |
+
dropout_rate = 0.0 # if not self.training else self.attn_dropout
|
458 |
+
|
459 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
460 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
461 |
+
# cast them back in float16 just to be sure everything works as expected.
|
462 |
+
input_dtype = query_states.dtype
|
463 |
+
if input_dtype == torch.float32:
|
464 |
+
# Handle the case where the model is quantized
|
465 |
+
if hasattr(self.config, "_pre_quantization_dtype"):
|
466 |
+
target_dtype = self.config._pre_quantization_dtype
|
467 |
+
else:
|
468 |
+
target_dtype = self.q_proj.weight.dtype
|
469 |
+
|
470 |
+
logger.warning_once(
|
471 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
472 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
473 |
+
f" {target_dtype}."
|
474 |
+
)
|
475 |
+
|
476 |
+
query_states = query_states.to(target_dtype)
|
477 |
+
key_states = key_states.to(target_dtype)
|
478 |
+
value_states = value_states.to(target_dtype)
|
479 |
+
|
480 |
+
# Reashape to the expected shape for Flash Attention
|
481 |
+
query_states = query_states.transpose(1, 2)
|
482 |
+
key_states = key_states.transpose(1, 2)
|
483 |
+
value_states = value_states.transpose(1, 2)
|
484 |
+
|
485 |
+
attn_output = self._flash_attention_forward(
|
486 |
+
query_states,
|
487 |
+
key_states,
|
488 |
+
value_states,
|
489 |
+
attention_mask,
|
490 |
+
q_len,
|
491 |
+
dropout=dropout_rate,
|
492 |
+
use_sliding_windows=use_sliding_windows,
|
493 |
+
)
|
494 |
+
|
495 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
496 |
+
attn_output = self.o_proj(attn_output)
|
497 |
+
|
498 |
+
if not output_attentions:
|
499 |
+
attn_weights = None
|
500 |
+
|
501 |
+
return attn_output, attn_weights, past_key_value
|
502 |
+
|
503 |
+
def _flash_attention_forward(
|
504 |
+
self,
|
505 |
+
query_states,
|
506 |
+
key_states,
|
507 |
+
value_states,
|
508 |
+
attention_mask,
|
509 |
+
query_length,
|
510 |
+
dropout=0.0,
|
511 |
+
softmax_scale=None,
|
512 |
+
use_sliding_windows=False,
|
513 |
+
):
|
514 |
+
"""
|
515 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
516 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
517 |
+
|
518 |
+
Args:
|
519 |
+
query_states (`torch.Tensor`):
|
520 |
+
Input query states to be passed to Flash Attention API
|
521 |
+
key_states (`torch.Tensor`):
|
522 |
+
Input key states to be passed to Flash Attention API
|
523 |
+
value_states (`torch.Tensor`):
|
524 |
+
Input value states to be passed to Flash Attention API
|
525 |
+
attention_mask (`torch.Tensor`):
|
526 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
527 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
528 |
+
dropout (`int`, *optional*):
|
529 |
+
Attention dropout
|
530 |
+
softmax_scale (`float`, *optional*):
|
531 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
532 |
+
use_sliding_windows (`bool`, *optional*):
|
533 |
+
Whether to activate sliding window attention.
|
534 |
+
"""
|
535 |
+
# Contains at least one padding token in the sequence
|
536 |
+
if attention_mask is not None:
|
537 |
+
batch_size = query_states.shape[0]
|
538 |
+
(
|
539 |
+
query_states,
|
540 |
+
key_states,
|
541 |
+
value_states,
|
542 |
+
indices_q,
|
543 |
+
cu_seq_lens,
|
544 |
+
max_seq_lens,
|
545 |
+
) = self._upad_input(
|
546 |
+
query_states, key_states, value_states, attention_mask, query_length
|
547 |
+
)
|
548 |
+
|
549 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
550 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
551 |
+
|
552 |
+
if not use_sliding_windows:
|
553 |
+
attn_output_unpad = flash_attn_varlen_func(
|
554 |
+
query_states,
|
555 |
+
key_states,
|
556 |
+
value_states,
|
557 |
+
cu_seqlens_q=cu_seqlens_q,
|
558 |
+
cu_seqlens_k=cu_seqlens_k,
|
559 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
560 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
561 |
+
dropout_p=dropout,
|
562 |
+
softmax_scale=softmax_scale,
|
563 |
+
causal=self.is_causal,
|
564 |
+
)
|
565 |
+
else:
|
566 |
+
attn_output_unpad = flash_attn_varlen_func(
|
567 |
+
query_states,
|
568 |
+
key_states,
|
569 |
+
value_states,
|
570 |
+
cu_seqlens_q=cu_seqlens_q,
|
571 |
+
cu_seqlens_k=cu_seqlens_k,
|
572 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
573 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
574 |
+
dropout_p=dropout,
|
575 |
+
softmax_scale=softmax_scale,
|
576 |
+
causal=self.is_causal,
|
577 |
+
window_size=(
|
578 |
+
self.config.sliding_window,
|
579 |
+
self.config.sliding_window,
|
580 |
+
),
|
581 |
+
)
|
582 |
+
|
583 |
+
attn_output = pad_input(
|
584 |
+
attn_output_unpad, indices_q, batch_size, query_length
|
585 |
+
)
|
586 |
+
else:
|
587 |
+
if not use_sliding_windows:
|
588 |
+
attn_output = flash_attn_func(
|
589 |
+
query_states,
|
590 |
+
key_states,
|
591 |
+
value_states,
|
592 |
+
dropout,
|
593 |
+
softmax_scale=softmax_scale,
|
594 |
+
causal=self.is_causal,
|
595 |
+
)
|
596 |
+
else:
|
597 |
+
attn_output = flash_attn_func(
|
598 |
+
query_states,
|
599 |
+
key_states,
|
600 |
+
value_states,
|
601 |
+
dropout,
|
602 |
+
softmax_scale=softmax_scale,
|
603 |
+
causal=self.is_causal,
|
604 |
+
window_size=(
|
605 |
+
self.config.sliding_window,
|
606 |
+
self.config.sliding_window,
|
607 |
+
),
|
608 |
+
)
|
609 |
+
|
610 |
+
return attn_output
|
611 |
+
|
612 |
+
def _upad_input(
|
613 |
+
self, query_layer, key_layer, value_layer, attention_mask, query_length
|
614 |
+
):
|
615 |
+
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
616 |
+
|
617 |
+
# On the first iteration we need to properly re-create the padding mask
|
618 |
+
# by slicing it on the proper place
|
619 |
+
if kv_seq_len != attention_mask.shape[-1]:
|
620 |
+
attention_mask_num_tokens = attention_mask.shape[-1]
|
621 |
+
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
|
622 |
+
|
623 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
624 |
+
|
625 |
+
key_layer = index_first_axis(
|
626 |
+
key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
627 |
+
)
|
628 |
+
value_layer = index_first_axis(
|
629 |
+
value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
630 |
+
)
|
631 |
+
|
632 |
+
if query_length == kv_seq_len:
|
633 |
+
query_layer = index_first_axis(
|
634 |
+
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
|
635 |
+
indices_k,
|
636 |
+
)
|
637 |
+
cu_seqlens_q = cu_seqlens_k
|
638 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
639 |
+
indices_q = indices_k
|
640 |
+
elif query_length == 1:
|
641 |
+
max_seqlen_in_batch_q = 1
|
642 |
+
cu_seqlens_q = torch.arange(
|
643 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
644 |
+
) # There is a memcpy here, that is very bad.
|
645 |
+
indices_q = cu_seqlens_q[:-1]
|
646 |
+
query_layer = query_layer.squeeze(1)
|
647 |
+
else:
|
648 |
+
# The -q_len: slice assumes left padding.
|
649 |
+
attention_mask = attention_mask[:, -query_length:]
|
650 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
651 |
+
query_layer, attention_mask
|
652 |
+
)
|
653 |
+
|
654 |
+
return (
|
655 |
+
query_layer,
|
656 |
+
key_layer,
|
657 |
+
value_layer,
|
658 |
+
indices_q,
|
659 |
+
(cu_seqlens_q, cu_seqlens_k),
|
660 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
661 |
+
)
|
662 |
+
|
663 |
+
|
664 |
+
class MistralDecoderLayer(nn.Module):
|
665 |
+
def __init__(self, config: EvoMistralConfig):
|
666 |
+
super().__init__()
|
667 |
+
self.hidden_size = config.hidden_size
|
668 |
+
self.self_attn = (
|
669 |
+
MistralAttention(config=config)
|
670 |
+
if not getattr(config, "_flash_attn_2_enabled", False)
|
671 |
+
else MistralFlashAttention2(config)
|
672 |
+
)
|
673 |
+
self.mlp = MistralMLP(config)
|
674 |
+
self.input_layernorm = MistralRMSNorm(
|
675 |
+
config.hidden_size, eps=config.rms_norm_eps
|
676 |
+
)
|
677 |
+
self.post_attention_layernorm = MistralRMSNorm(
|
678 |
+
config.hidden_size, eps=config.rms_norm_eps
|
679 |
+
)
|
680 |
+
|
681 |
+
def forward(
|
682 |
+
self,
|
683 |
+
hidden_states: torch.Tensor,
|
684 |
+
attention_mask: Optional[torch.Tensor] = None,
|
685 |
+
position_ids: Optional[torch.LongTensor] = None,
|
686 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
687 |
+
output_attentions: Optional[bool] = False,
|
688 |
+
use_cache: Optional[bool] = False,
|
689 |
+
residual: Optional[torch.Tensor] = None,
|
690 |
+
**kwargs,
|
691 |
+
) -> Tuple[
|
692 |
+
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
693 |
+
]:
|
694 |
+
if "padding_mask" in kwargs:
|
695 |
+
warnings.warn(
|
696 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
697 |
+
)
|
698 |
+
"""
|
699 |
+
Args:
|
700 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
701 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
702 |
+
`(batch, sequence_length)` where padding elements are indicated by 0.
|
703 |
+
output_attentions (`bool`, *optional*):
|
704 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
705 |
+
returned tensors for more detail.
|
706 |
+
use_cache (`bool`, *optional*):
|
707 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
708 |
+
(see `past_key_values`).
|
709 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
710 |
+
"""
|
711 |
+
if residual is None:
|
712 |
+
residual = hidden_states
|
713 |
+
hidden_states = self.input_layernorm(hidden_states)
|
714 |
+
else:
|
715 |
+
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
716 |
+
|
717 |
+
# Self Attention
|
718 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
719 |
+
hidden_states=hidden_states,
|
720 |
+
attention_mask=attention_mask,
|
721 |
+
position_ids=position_ids,
|
722 |
+
past_key_value=past_key_value,
|
723 |
+
output_attentions=output_attentions,
|
724 |
+
use_cache=use_cache,
|
725 |
+
)
|
726 |
+
|
727 |
+
# Fully Connected
|
728 |
+
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
729 |
+
hidden_states = self.mlp(hidden_states)
|
730 |
+
|
731 |
+
outputs = ((hidden_states, residual),)
|
732 |
+
|
733 |
+
if output_attentions:
|
734 |
+
outputs += (self_attn_weights,)
|
735 |
+
|
736 |
+
if use_cache:
|
737 |
+
outputs += (present_key_value,)
|
738 |
+
|
739 |
+
return outputs
|
740 |
+
|
741 |
+
|
742 |
+
MISTRAL_START_DOCSTRING = r"""
|
743 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
744 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
745 |
+
etc.)
|
746 |
+
|
747 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
748 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
749 |
+
and behavior.
|
750 |
+
|
751 |
+
Parameters:
|
752 |
+
config ([`MistralConfig`]):
|
753 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
754 |
+
load the weights associated with the model, only the configuration. Check out the
|
755 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
756 |
+
"""
|
757 |
+
|
758 |
+
|
759 |
+
@add_start_docstrings(
|
760 |
+
"The bare Mistral Model outputting raw hidden-states without any specific head on top.",
|
761 |
+
MISTRAL_START_DOCSTRING,
|
762 |
+
)
|
763 |
+
class MistralPreTrainedModel(PreTrainedModel):
|
764 |
+
config_class = EvoMistralConfig
|
765 |
+
base_model_prefix = "model"
|
766 |
+
supports_gradient_checkpointing = True
|
767 |
+
_no_split_modules = ["MistralDecoderLayer"]
|
768 |
+
_skip_keys_device_placement = "past_key_values"
|
769 |
+
_supports_flash_attn_2 = True
|
770 |
+
|
771 |
+
def _init_weights(self, module):
|
772 |
+
std = self.config.initializer_range
|
773 |
+
if isinstance(module, nn.Linear):
|
774 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
775 |
+
if module.bias is not None:
|
776 |
+
module.bias.data.zero_()
|
777 |
+
elif isinstance(module, nn.Embedding):
|
778 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
779 |
+
if module.padding_idx is not None:
|
780 |
+
module.weight.data[module.padding_idx].zero_()
|
781 |
+
|
782 |
+
|
783 |
+
MISTRAL_INPUTS_DOCSTRING = r"""
|
784 |
+
Args:
|
785 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
786 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
787 |
+
it.
|
788 |
+
|
789 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
790 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
791 |
+
|
792 |
+
[What are input IDs?](../glossary#input-ids)
|
793 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
794 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
795 |
+
|
796 |
+
- 1 for tokens that are **not masked**,
|
797 |
+
- 0 for tokens that are **masked**.
|
798 |
+
|
799 |
+
[What are attention masks?](../glossary#attention-mask)
|
800 |
+
|
801 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
802 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
803 |
+
|
804 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
805 |
+
`past_key_values`).
|
806 |
+
|
807 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
808 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
809 |
+
information on the default strategy.
|
810 |
+
|
811 |
+
- 1 indicates the head is **not masked**,
|
812 |
+
- 0 indicates the head is **masked**.
|
813 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
814 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
815 |
+
config.n_positions - 1]`.
|
816 |
+
|
817 |
+
[What are position IDs?](../glossary#position-ids)
|
818 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
819 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
820 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
821 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
822 |
+
|
823 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
824 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
825 |
+
|
826 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
827 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
828 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
829 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
830 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
831 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
832 |
+
model's internal embedding lookup matrix.
|
833 |
+
use_cache (`bool`, *optional*):
|
834 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
835 |
+
`past_key_values`).
|
836 |
+
output_attentions (`bool`, *optional*):
|
837 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
838 |
+
tensors for more detail.
|
839 |
+
output_hidden_states (`bool`, *optional*):
|
840 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
841 |
+
more detail.
|
842 |
+
return_dict (`bool`, *optional*):
|
843 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
844 |
+
"""
|
845 |
+
|
846 |
+
|
847 |
+
@add_start_docstrings(
|
848 |
+
"The bare Mistral Model outputting raw hidden-states without any specific head on top.",
|
849 |
+
MISTRAL_START_DOCSTRING,
|
850 |
+
)
|
851 |
+
class EvoMistralModel(MistralPreTrainedModel):
|
852 |
+
"""
|
853 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
|
854 |
+
|
855 |
+
Args:
|
856 |
+
config: MistralConfig
|
857 |
+
"""
|
858 |
+
|
859 |
+
def __init__(self, config: EvoMistralConfig):
|
860 |
+
super().__init__(config)
|
861 |
+
self.padding_idx = config.pad_token_id
|
862 |
+
self.vocab_size = config.vocab_size
|
863 |
+
|
864 |
+
self.embed_tokens = nn.Embedding(
|
865 |
+
config.vocab_size, config.hidden_size, self.padding_idx
|
866 |
+
)
|
867 |
+
self.layers = nn.ModuleList(
|
868 |
+
[MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
869 |
+
)
|
870 |
+
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
871 |
+
|
872 |
+
self.input_scales = nn.Parameter(
|
873 |
+
data=torch.zeros(self.config.num_hops).float(), requires_grad=False
|
874 |
+
)
|
875 |
+
self.input_layers = nn.Parameter(
|
876 |
+
data=torch.zeros(self.config.num_hops).int(), requires_grad=False
|
877 |
+
)
|
878 |
+
|
879 |
+
self.gradient_checkpointing = False
|
880 |
+
# Initialize weights and apply final processing
|
881 |
+
self.post_init()
|
882 |
+
|
883 |
+
def get_input_embeddings(self):
|
884 |
+
return self.embed_tokens
|
885 |
+
|
886 |
+
def set_input_embeddings(self, value):
|
887 |
+
self.embed_tokens = value
|
888 |
+
|
889 |
+
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
890 |
+
def forward(
|
891 |
+
self,
|
892 |
+
input_ids: torch.LongTensor = None,
|
893 |
+
attention_mask: Optional[torch.Tensor] = None,
|
894 |
+
position_ids: Optional[torch.LongTensor] = None,
|
895 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
896 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
897 |
+
use_cache: Optional[bool] = None,
|
898 |
+
output_attentions: Optional[bool] = None,
|
899 |
+
output_hidden_states: Optional[bool] = None,
|
900 |
+
return_dict: Optional[bool] = None,
|
901 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
902 |
+
output_attentions = (
|
903 |
+
output_attentions
|
904 |
+
if output_attentions is not None
|
905 |
+
else self.config.output_attentions
|
906 |
+
)
|
907 |
+
output_hidden_states = (
|
908 |
+
output_hidden_states
|
909 |
+
if output_hidden_states is not None
|
910 |
+
else self.config.output_hidden_states
|
911 |
+
)
|
912 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
913 |
+
|
914 |
+
return_dict = (
|
915 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
916 |
+
)
|
917 |
+
|
918 |
+
# retrieve input_ids and inputs_embeds
|
919 |
+
if input_ids is not None and inputs_embeds is not None:
|
920 |
+
raise ValueError(
|
921 |
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
922 |
+
)
|
923 |
+
elif input_ids is not None:
|
924 |
+
batch_size, seq_length = input_ids.shape
|
925 |
+
elif inputs_embeds is not None:
|
926 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
927 |
+
else:
|
928 |
+
raise ValueError(
|
929 |
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
930 |
+
)
|
931 |
+
|
932 |
+
seq_length_with_past = seq_length
|
933 |
+
past_key_values_length = 0
|
934 |
+
|
935 |
+
if past_key_values is not None:
|
936 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
937 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
938 |
+
|
939 |
+
if position_ids is None:
|
940 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
941 |
+
position_ids = torch.arange(
|
942 |
+
past_key_values_length,
|
943 |
+
seq_length + past_key_values_length,
|
944 |
+
dtype=torch.long,
|
945 |
+
device=device,
|
946 |
+
)
|
947 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
948 |
+
else:
|
949 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
950 |
+
|
951 |
+
if inputs_embeds is None:
|
952 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
953 |
+
|
954 |
+
if (
|
955 |
+
attention_mask is not None
|
956 |
+
and hasattr(self.config, "_flash_attn_2_enabled")
|
957 |
+
and self.config._flash_attn_2_enabled
|
958 |
+
and past_key_values is not None
|
959 |
+
):
|
960 |
+
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
961 |
+
if is_padding_right:
|
962 |
+
raise ValueError(
|
963 |
+
"You are attempting to perform batched generation with padding_side='right'"
|
964 |
+
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
965 |
+
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
966 |
+
)
|
967 |
+
|
968 |
+
if getattr(self.config, "_flash_attn_2_enabled", False):
|
969 |
+
# 2d mask is passed through the layers
|
970 |
+
attention_mask = (
|
971 |
+
attention_mask
|
972 |
+
if (attention_mask is not None and 0 in attention_mask)
|
973 |
+
else None
|
974 |
+
)
|
975 |
+
else:
|
976 |
+
# 4d mask is passed through the layers
|
977 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
978 |
+
attention_mask,
|
979 |
+
(batch_size, seq_length),
|
980 |
+
inputs_embeds,
|
981 |
+
past_key_values_length,
|
982 |
+
sliding_window=self.config.sliding_window,
|
983 |
+
)
|
984 |
+
|
985 |
+
hidden_states = inputs_embeds
|
986 |
+
|
987 |
+
if self.gradient_checkpointing and self.training:
|
988 |
+
if use_cache:
|
989 |
+
logger.warning_once(
|
990 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
991 |
+
)
|
992 |
+
use_cache = False
|
993 |
+
|
994 |
+
# decoder layers
|
995 |
+
all_hidden_states = () if output_hidden_states else None
|
996 |
+
all_self_attns = () if output_attentions else None
|
997 |
+
next_decoder_cache = () if use_cache else None
|
998 |
+
residual = None
|
999 |
+
|
1000 |
+
for idx, layer_ix in enumerate(self.input_layers):
|
1001 |
+
decoder_layer = self.layers[layer_ix]
|
1002 |
+
scale = self.input_scales[idx].to(hidden_states.device)
|
1003 |
+
|
1004 |
+
if output_hidden_states:
|
1005 |
+
all_hidden_states += (hidden_states,)
|
1006 |
+
|
1007 |
+
past_key_value = (
|
1008 |
+
past_key_values[idx] if past_key_values is not None else None
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
if self.gradient_checkpointing and self.training:
|
1012 |
+
layer_outputs = self._gradient_checkpointing_func(
|
1013 |
+
decoder_layer.__call__,
|
1014 |
+
hidden_states * scale,
|
1015 |
+
attention_mask,
|
1016 |
+
position_ids,
|
1017 |
+
past_key_value,
|
1018 |
+
output_attentions,
|
1019 |
+
use_cache,
|
1020 |
+
residual,
|
1021 |
+
)
|
1022 |
+
else:
|
1023 |
+
layer_outputs = decoder_layer(
|
1024 |
+
hidden_states * scale,
|
1025 |
+
attention_mask=attention_mask,
|
1026 |
+
position_ids=position_ids,
|
1027 |
+
past_key_value=past_key_value,
|
1028 |
+
output_attentions=output_attentions,
|
1029 |
+
use_cache=use_cache,
|
1030 |
+
residual=residual,
|
1031 |
+
)
|
1032 |
+
|
1033 |
+
hidden_states, residual = layer_outputs[0]
|
1034 |
+
|
1035 |
+
if use_cache:
|
1036 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
1037 |
+
|
1038 |
+
if output_attentions:
|
1039 |
+
all_self_attns += (layer_outputs[1],)
|
1040 |
+
|
1041 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
1042 |
+
|
1043 |
+
# add hidden states from the last decoder layer
|
1044 |
+
if output_hidden_states:
|
1045 |
+
all_hidden_states += (hidden_states,)
|
1046 |
+
|
1047 |
+
next_cache = next_decoder_cache if use_cache else None
|
1048 |
+
if not return_dict:
|
1049 |
+
return tuple(
|
1050 |
+
v
|
1051 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
1052 |
+
if v is not None
|
1053 |
+
)
|
1054 |
+
return BaseModelOutputWithPast(
|
1055 |
+
last_hidden_state=hidden_states,
|
1056 |
+
past_key_values=next_cache,
|
1057 |
+
hidden_states=all_hidden_states,
|
1058 |
+
attentions=all_self_attns,
|
1059 |
+
)
|
1060 |
+
|
1061 |
+
|
1062 |
+
class EvoMistralForCausalLM(MistralPreTrainedModel):
|
1063 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1064 |
+
|
1065 |
+
def __init__(self, config):
|
1066 |
+
super().__init__(config)
|
1067 |
+
self.model = EvoMistralModel(config)
|
1068 |
+
self.vocab_size = config.vocab_size
|
1069 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1070 |
+
|
1071 |
+
# Initialize weights and apply final processing
|
1072 |
+
self.post_init()
|
1073 |
+
|
1074 |
+
def get_input_embeddings(self):
|
1075 |
+
return self.model.embed_tokens
|
1076 |
+
|
1077 |
+
def set_input_embeddings(self, value):
|
1078 |
+
self.model.embed_tokens = value
|
1079 |
+
|
1080 |
+
def get_output_embeddings(self):
|
1081 |
+
return self.lm_head
|
1082 |
+
|
1083 |
+
def set_output_embeddings(self, new_embeddings):
|
1084 |
+
self.lm_head = new_embeddings
|
1085 |
+
|
1086 |
+
def set_decoder(self, decoder):
|
1087 |
+
self.model = decoder
|
1088 |
+
|
1089 |
+
def get_decoder(self):
|
1090 |
+
return self.model
|
1091 |
+
|
1092 |
+
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
1093 |
+
@replace_return_docstrings(
|
1094 |
+
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
1095 |
+
)
|
1096 |
+
def forward(
|
1097 |
+
self,
|
1098 |
+
input_ids: torch.LongTensor = None,
|
1099 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1100 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1101 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1102 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1103 |
+
labels: Optional[torch.LongTensor] = None,
|
1104 |
+
use_cache: Optional[bool] = None,
|
1105 |
+
output_attentions: Optional[bool] = None,
|
1106 |
+
output_hidden_states: Optional[bool] = None,
|
1107 |
+
return_dict: Optional[bool] = None,
|
1108 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1109 |
+
r"""
|
1110 |
+
Args:
|
1111 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1112 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1113 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1114 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1115 |
+
|
1116 |
+
Returns:
|
1117 |
+
|
1118 |
+
Example:
|
1119 |
+
|
1120 |
+
```python
|
1121 |
+
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
1122 |
+
|
1123 |
+
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
1124 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
1125 |
+
|
1126 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
1127 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1128 |
+
|
1129 |
+
>>> # Generate
|
1130 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1131 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1132 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
1133 |
+
```"""
|
1134 |
+
|
1135 |
+
output_attentions = (
|
1136 |
+
output_attentions
|
1137 |
+
if output_attentions is not None
|
1138 |
+
else self.config.output_attentions
|
1139 |
+
)
|
1140 |
+
output_hidden_states = (
|
1141 |
+
output_hidden_states
|
1142 |
+
if output_hidden_states is not None
|
1143 |
+
else self.config.output_hidden_states
|
1144 |
+
)
|
1145 |
+
return_dict = (
|
1146 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1147 |
+
)
|
1148 |
+
|
1149 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1150 |
+
outputs = self.model(
|
1151 |
+
input_ids=input_ids,
|
1152 |
+
attention_mask=attention_mask,
|
1153 |
+
position_ids=position_ids,
|
1154 |
+
past_key_values=past_key_values,
|
1155 |
+
inputs_embeds=inputs_embeds,
|
1156 |
+
use_cache=use_cache,
|
1157 |
+
output_attentions=output_attentions,
|
1158 |
+
output_hidden_states=output_hidden_states,
|
1159 |
+
return_dict=return_dict,
|
1160 |
+
)
|
1161 |
+
|
1162 |
+
hidden_states = outputs[0]
|
1163 |
+
logits = self.lm_head(hidden_states)
|
1164 |
+
logits = logits.float()
|
1165 |
+
|
1166 |
+
loss = None
|
1167 |
+
if labels is not None:
|
1168 |
+
# Shift so that tokens < n predict n
|
1169 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1170 |
+
shift_labels = labels[..., 1:].contiguous()
|
1171 |
+
# Flatten the tokens
|
1172 |
+
loss_fct = CrossEntropyLoss()
|
1173 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1174 |
+
shift_labels = shift_labels.view(-1)
|
1175 |
+
# Enable model parallelism
|
1176 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1177 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1178 |
+
|
1179 |
+
if not return_dict:
|
1180 |
+
output = (logits,) + outputs[1:]
|
1181 |
+
return (loss,) + output if loss is not None else output
|
1182 |
+
|
1183 |
+
return CausalLMOutputWithPast(
|
1184 |
+
loss=loss,
|
1185 |
+
logits=logits,
|
1186 |
+
past_key_values=outputs.past_key_values,
|
1187 |
+
hidden_states=outputs.hidden_states,
|
1188 |
+
attentions=outputs.attentions,
|
1189 |
+
)
|
1190 |
+
|
1191 |
+
def prepare_inputs_for_generation(
|
1192 |
+
self,
|
1193 |
+
input_ids,
|
1194 |
+
past_key_values=None,
|
1195 |
+
attention_mask=None,
|
1196 |
+
inputs_embeds=None,
|
1197 |
+
**kwargs,
|
1198 |
+
):
|
1199 |
+
# Omit tokens covered by past_key_values
|
1200 |
+
if past_key_values:
|
1201 |
+
past_length = past_key_values[0][0].shape[2]
|
1202 |
+
|
1203 |
+
# Some generation methods already pass only the last input ID
|
1204 |
+
if input_ids.shape[1] > past_length:
|
1205 |
+
remove_prefix_length = past_length
|
1206 |
+
else:
|
1207 |
+
# Default to old behavior: keep only final ID
|
1208 |
+
remove_prefix_length = input_ids.shape[1] - 1
|
1209 |
+
|
1210 |
+
input_ids = input_ids[:, remove_prefix_length:]
|
1211 |
+
|
1212 |
+
position_ids = kwargs.get("position_ids", None)
|
1213 |
+
if attention_mask is not None and position_ids is None:
|
1214 |
+
# create position_ids on the fly for batch generation
|
1215 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1216 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1217 |
+
if past_key_values:
|
1218 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1219 |
+
|
1220 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1221 |
+
if inputs_embeds is not None and past_key_values is None:
|
1222 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1223 |
+
else:
|
1224 |
+
model_inputs = {"input_ids": input_ids}
|
1225 |
+
|
1226 |
+
model_inputs.update(
|
1227 |
+
{
|
1228 |
+
"position_ids": position_ids,
|
1229 |
+
"past_key_values": past_key_values,
|
1230 |
+
"use_cache": kwargs.get("use_cache"),
|
1231 |
+
"attention_mask": attention_mask,
|
1232 |
+
}
|
1233 |
+
)
|
1234 |
+
return model_inputs
|
1235 |
+
|
1236 |
+
@staticmethod
|
1237 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1238 |
+
reordered_past = ()
|
1239 |
+
for layer_past in past_key_values:
|
1240 |
+
reordered_past += (
|
1241 |
+
tuple(
|
1242 |
+
past_state.index_select(0, beam_idx.to(past_state.device))
|
1243 |
+
for past_state in layer_past
|
1244 |
+
),
|
1245 |
+
)
|
1246 |
+
return reordered_past
|
1247 |
+
|
1248 |
+
|
1249 |
+
@add_start_docstrings(
|
1250 |
+
"""
|
1251 |
+
The Mistral Model transformer with a sequence classification head on top (linear layer).
|
1252 |
+
|
1253 |
+
[`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1254 |
+
(e.g. GPT-2) do.
|
1255 |
+
|
1256 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1257 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1258 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1259 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1260 |
+
each row of the batch).
|
1261 |
+
""",
|
1262 |
+
MISTRAL_START_DOCSTRING,
|
1263 |
+
)
|
1264 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
|
1265 |
+
class EvoMistralForSequenceClassification(MistralPreTrainedModel):
|
1266 |
+
def __init__(self, config):
|
1267 |
+
super().__init__(config)
|
1268 |
+
self.num_labels = config.num_labels
|
1269 |
+
self.model = EvoMistralModel(config)
|
1270 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1271 |
+
|
1272 |
+
# Initialize weights and apply final processing
|
1273 |
+
self.post_init()
|
1274 |
+
|
1275 |
+
def get_input_embeddings(self):
|
1276 |
+
return self.model.embed_tokens
|
1277 |
+
|
1278 |
+
def set_input_embeddings(self, value):
|
1279 |
+
self.model.embed_tokens = value
|
1280 |
+
|
1281 |
+
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
1282 |
+
def forward(
|
1283 |
+
self,
|
1284 |
+
input_ids: torch.LongTensor = None,
|
1285 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1286 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1287 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1288 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1289 |
+
labels: Optional[torch.LongTensor] = None,
|
1290 |
+
use_cache: Optional[bool] = None,
|
1291 |
+
output_attentions: Optional[bool] = None,
|
1292 |
+
output_hidden_states: Optional[bool] = None,
|
1293 |
+
return_dict: Optional[bool] = None,
|
1294 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1295 |
+
r"""
|
1296 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1297 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1298 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1299 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1300 |
+
"""
|
1301 |
+
return_dict = (
|
1302 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1303 |
+
)
|
1304 |
+
|
1305 |
+
transformer_outputs = self.model(
|
1306 |
+
input_ids,
|
1307 |
+
attention_mask=attention_mask,
|
1308 |
+
position_ids=position_ids,
|
1309 |
+
past_key_values=past_key_values,
|
1310 |
+
inputs_embeds=inputs_embeds,
|
1311 |
+
use_cache=use_cache,
|
1312 |
+
output_attentions=output_attentions,
|
1313 |
+
output_hidden_states=output_hidden_states,
|
1314 |
+
return_dict=return_dict,
|
1315 |
+
)
|
1316 |
+
hidden_states = transformer_outputs[0]
|
1317 |
+
logits = self.score(hidden_states)
|
1318 |
+
|
1319 |
+
if input_ids is not None:
|
1320 |
+
batch_size = input_ids.shape[0]
|
1321 |
+
else:
|
1322 |
+
batch_size = inputs_embeds.shape[0]
|
1323 |
+
|
1324 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1325 |
+
raise ValueError(
|
1326 |
+
"Cannot handle batch sizes > 1 if no padding token is defined."
|
1327 |
+
)
|
1328 |
+
if self.config.pad_token_id is None:
|
1329 |
+
sequence_lengths = -1
|
1330 |
+
else:
|
1331 |
+
if input_ids is not None:
|
1332 |
+
sequence_lengths = (
|
1333 |
+
torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
|
1334 |
+
).to(logits.device)
|
1335 |
+
else:
|
1336 |
+
sequence_lengths = -1
|
1337 |
+
|
1338 |
+
pooled_logits = logits[
|
1339 |
+
torch.arange(batch_size, device=logits.device), sequence_lengths
|
1340 |
+
]
|
1341 |
+
|
1342 |
+
loss = None
|
1343 |
+
if labels is not None:
|
1344 |
+
labels = labels.to(logits.device)
|
1345 |
+
if self.config.problem_type is None:
|
1346 |
+
if self.num_labels == 1:
|
1347 |
+
self.config.problem_type = "regression"
|
1348 |
+
elif self.num_labels > 1 and (
|
1349 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
1350 |
+
):
|
1351 |
+
self.config.problem_type = "single_label_classification"
|
1352 |
+
else:
|
1353 |
+
self.config.problem_type = "multi_label_classification"
|
1354 |
+
|
1355 |
+
if self.config.problem_type == "regression":
|
1356 |
+
loss_fct = MSELoss()
|
1357 |
+
if self.num_labels == 1:
|
1358 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1359 |
+
else:
|
1360 |
+
loss = loss_fct(pooled_logits, labels)
|
1361 |
+
elif self.config.problem_type == "single_label_classification":
|
1362 |
+
loss_fct = CrossEntropyLoss()
|
1363 |
+
loss = loss_fct(
|
1364 |
+
pooled_logits.view(-1, self.num_labels), labels.view(-1)
|
1365 |
+
)
|
1366 |
+
elif self.config.problem_type == "multi_label_classification":
|
1367 |
+
loss_fct = BCEWithLogitsLoss()
|
1368 |
+
loss = loss_fct(pooled_logits, labels)
|
1369 |
+
if not return_dict:
|
1370 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1371 |
+
return ((loss,) + output) if loss is not None else output
|
1372 |
+
|
1373 |
+
return SequenceClassifierOutputWithPast(
|
1374 |
+
loss=loss,
|
1375 |
+
logits=pooled_logits,
|
1376 |
+
past_key_values=transformer_outputs.past_key_values,
|
1377 |
+
hidden_states=transformer_outputs.hidden_states,
|
1378 |
+
attentions=transformer_outputs.attentions,
|
1379 |
+
)
|
special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "</s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"unk_token": {
|
17 |
+
"content": "<unk>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
}
|
23 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
|
3 |
+
size 493443
|
tokenizer_config.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<unk>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<s>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "</s>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
}
|
29 |
+
},
|
30 |
+
"additional_special_tokens": [],
|
31 |
+
"bos_token": "<s>",
|
32 |
+
"clean_up_tokenization_spaces": false,
|
33 |
+
"eos_token": "</s>",
|
34 |
+
"legacy": true,
|
35 |
+
"model_max_length": 1000000000000000019884624838656,
|
36 |
+
"pad_token": null,
|
37 |
+
"sp_model_kwargs": {},
|
38 |
+
"spaces_between_special_tokens": false,
|
39 |
+
"tokenizer_class": "LlamaTokenizer",
|
40 |
+
"unk_token": "<unk>",
|
41 |
+
"use_default_system_prompt": false
|
42 |
+
}
|