Upload 3 files
#1
by
bapatra
- opened
- .gitattributes +35 -0
- .gitignore +0 -160
- CODE_OF_CONDUCT.md +9 -0
- LICENSE +21 -0
- README.md +275 -0
- SECURITY.md +41 -0
- SUPPORT.md +25 -0
- cl100k_base.tiktoken +0 -0
- config.json +47 -0
- configuration_phi3_small.py +250 -0
- generation_config.json +9 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +426 -0
- modeling_phi3_small.py +1140 -0
- positional_embedding.py +288 -0
- special_tokens_map.json +5 -0
- tokenization_phi3_small.py +315 -0
- tokenizer_config.json +16 -0
- triton_blocksparse_attention_layer.py +176 -0
- triton_flash_blocksparse_attn.py +1943 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
DELETED
@@ -1,160 +0,0 @@
|
|
1 |
-
# Byte-compiled / optimized / DLL files
|
2 |
-
__pycache__/
|
3 |
-
*.py[cod]
|
4 |
-
*$py.class
|
5 |
-
|
6 |
-
# C extensions
|
7 |
-
*.so
|
8 |
-
|
9 |
-
# Distribution / packaging
|
10 |
-
.Python
|
11 |
-
build/
|
12 |
-
develop-eggs/
|
13 |
-
dist/
|
14 |
-
downloads/
|
15 |
-
eggs/
|
16 |
-
.eggs/
|
17 |
-
lib/
|
18 |
-
lib64/
|
19 |
-
parts/
|
20 |
-
sdist/
|
21 |
-
var/
|
22 |
-
wheels/
|
23 |
-
share/python-wheels/
|
24 |
-
*.egg-info/
|
25 |
-
.installed.cfg
|
26 |
-
*.egg
|
27 |
-
MANIFEST
|
28 |
-
|
29 |
-
# PyInstaller
|
30 |
-
# Usually these files are written by a python script from a template
|
31 |
-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
-
*.manifest
|
33 |
-
*.spec
|
34 |
-
|
35 |
-
# Installer logs
|
36 |
-
pip-log.txt
|
37 |
-
pip-delete-this-directory.txt
|
38 |
-
|
39 |
-
# Unit test / coverage reports
|
40 |
-
htmlcov/
|
41 |
-
.tox/
|
42 |
-
.nox/
|
43 |
-
.coverage
|
44 |
-
.coverage.*
|
45 |
-
.cache
|
46 |
-
nosetests.xml
|
47 |
-
coverage.xml
|
48 |
-
*.cover
|
49 |
-
*.py,cover
|
50 |
-
.hypothesis/
|
51 |
-
.pytest_cache/
|
52 |
-
cover/
|
53 |
-
|
54 |
-
# Translations
|
55 |
-
*.mo
|
56 |
-
*.pot
|
57 |
-
|
58 |
-
# Django stuff:
|
59 |
-
*.log
|
60 |
-
local_settings.py
|
61 |
-
db.sqlite3
|
62 |
-
db.sqlite3-journal
|
63 |
-
|
64 |
-
# Flask stuff:
|
65 |
-
instance/
|
66 |
-
.webassets-cache
|
67 |
-
|
68 |
-
# Scrapy stuff:
|
69 |
-
.scrapy
|
70 |
-
|
71 |
-
# Sphinx documentation
|
72 |
-
docs/_build/
|
73 |
-
|
74 |
-
# PyBuilder
|
75 |
-
.pybuilder/
|
76 |
-
target/
|
77 |
-
|
78 |
-
# Jupyter Notebook
|
79 |
-
.ipynb_checkpoints
|
80 |
-
|
81 |
-
# IPython
|
82 |
-
profile_default/
|
83 |
-
ipython_config.py
|
84 |
-
|
85 |
-
# pyenv
|
86 |
-
# For a library or package, you might want to ignore these files since the code is
|
87 |
-
# intended to run in multiple environments; otherwise, check them in:
|
88 |
-
# .python-version
|
89 |
-
|
90 |
-
# pipenv
|
91 |
-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
-
# install all needed dependencies.
|
95 |
-
#Pipfile.lock
|
96 |
-
|
97 |
-
# poetry
|
98 |
-
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
-
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
-
# commonly ignored for libraries.
|
101 |
-
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
-
#poetry.lock
|
103 |
-
|
104 |
-
# pdm
|
105 |
-
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
-
#pdm.lock
|
107 |
-
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
-
# in version control.
|
109 |
-
# https://pdm.fming.dev/#use-with-ide
|
110 |
-
.pdm.toml
|
111 |
-
|
112 |
-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
-
__pypackages__/
|
114 |
-
|
115 |
-
# Celery stuff
|
116 |
-
celerybeat-schedule
|
117 |
-
celerybeat.pid
|
118 |
-
|
119 |
-
# SageMath parsed files
|
120 |
-
*.sage.py
|
121 |
-
|
122 |
-
# Environments
|
123 |
-
.env
|
124 |
-
.venv
|
125 |
-
env/
|
126 |
-
venv/
|
127 |
-
ENV/
|
128 |
-
env.bak/
|
129 |
-
venv.bak/
|
130 |
-
|
131 |
-
# Spyder project settings
|
132 |
-
.spyderproject
|
133 |
-
.spyproject
|
134 |
-
|
135 |
-
# Rope project settings
|
136 |
-
.ropeproject
|
137 |
-
|
138 |
-
# mkdocs documentation
|
139 |
-
/site
|
140 |
-
|
141 |
-
# mypy
|
142 |
-
.mypy_cache/
|
143 |
-
.dmypy.json
|
144 |
-
dmypy.json
|
145 |
-
|
146 |
-
# Pyre type checker
|
147 |
-
.pyre/
|
148 |
-
|
149 |
-
# pytype static type analyzer
|
150 |
-
.pytype/
|
151 |
-
|
152 |
-
# Cython debug symbols
|
153 |
-
cython_debug/
|
154 |
-
|
155 |
-
# PyCharm
|
156 |
-
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
-
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
-
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
-
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
-
#.idea/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Microsoft Open Source Code of Conduct
|
2 |
+
|
3 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
4 |
+
|
5 |
+
Resources:
|
6 |
+
|
7 |
+
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
8 |
+
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
9 |
+
- Contact [[email protected]](mailto:[email protected]) with questions or concerns
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Microsoft Corporation.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE
|
README.md
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
license_link: https://huggingface.co/microsoft/Phi-3-small-8k-instruct/resolve/main/LICENSE
|
4 |
+
|
5 |
+
language:
|
6 |
+
- multilingual
|
7 |
+
pipeline_tag: text-generation
|
8 |
+
tags:
|
9 |
+
- nlp
|
10 |
+
- code
|
11 |
+
inference:
|
12 |
+
parameters:
|
13 |
+
temperature: 0.7
|
14 |
+
widget:
|
15 |
+
- messages:
|
16 |
+
- role: user
|
17 |
+
content: Can you provide ways to eat combinations of bananas and dragonfruits?
|
18 |
+
---
|
19 |
+
## Model Summary
|
20 |
+
|
21 |
+
The Phi-3-Small-8K-Instruct is a 7B parameters, lightweight, state-of-the-art open model trained with the Phi-3 datasets that includes both synthetic data and the filtered publicly available websites data with a focus on high-quality and reasoning dense properties.
|
22 |
+
The model belongs to the Phi-3 family with the Small version in two variants [8K](https://huggingface.co/microsoft/Phi-3-small-8k-instruct) and [128K](https://huggingface.co/microsoft/Phi-3-small-128k-instruct) which is the context length (in tokens) that it can support.
|
23 |
+
|
24 |
+
The model has underwent a post-training process that incorporates both supervised fine-tuning and direct preference optimization for the instruction following and safety measures.
|
25 |
+
When assessed against benchmarks testing common sense, language understanding, math, code, long context and logical reasoning, Phi-3-Small-8K-Instruct showcased a robust and state-of-the-art performance among models with less than 13 billion parameters.
|
26 |
+
|
27 |
+
Resources and Technical Documentation:
|
28 |
+
|
29 |
+
+ [Phi-3 Microsoft Blog](https://aka.ms/phi3blog-april)
|
30 |
+
+ [Phi-3 Technical Report](https://aka.ms/phi3-tech-report)
|
31 |
+
+ [Phi-3 on Azure AI Studio](https://aka.ms/phi3-azure-ai)
|
32 |
+
|
33 |
+
| | Short Context | Long Context |
|
34 |
+
| ------- | ------------- | ------------ |
|
35 |
+
| Mini | 4K [[HF]](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx) ; [[GGUF]](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf) | 128K [[HF]](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct-onnx)|
|
36 |
+
| Small | 8K [[HF]](https://huggingface.co/microsoft/Phi-3-small-8k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-small-8k-instruct-onnx) | 128K [[HF]](https://huggingface.co/microsoft/Phi-3-small-128k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-small-128k-instruct-onnx)|
|
37 |
+
| Medium | 4K [[HF]](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct-onnx) | 128K [[HF]](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct) ; [[ONNX]](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct-onnx)|
|
38 |
+
|
39 |
+
## Intended Uses
|
40 |
+
|
41 |
+
**Primary use cases**
|
42 |
+
|
43 |
+
The model is intended for broad commercial and research use in English. The model provides uses for general purpose AI systems and applications which require:
|
44 |
+
|
45 |
+
1) Memory/compute constrained environments
|
46 |
+
2) Latency bound scenarios
|
47 |
+
3) Strong reasoning (especially code, math and logic)
|
48 |
+
|
49 |
+
Our model is designed to accelerate research on language and multimodal models, for use as a building block for generative AI powered features.
|
50 |
+
|
51 |
+
**Use case considerations**
|
52 |
+
|
53 |
+
Our models are not specifically designed or evaluated for all downstream purposes. Developers should consider common limitations of language models as they select use cases, and evaluate and mitigate for accuracy, safety, and fariness before using within a specific downstream use case, particularly for high risk scenarios. Developers should be aware of and adhere to applicable laws or regulations (including privacy, trade compliance laws, etc.) that are relevant to their use case.
|
54 |
+
|
55 |
+
Nothing contained in this Model Card should be interpreted as or deemed a restriction or modification to the license the model is released under.
|
56 |
+
|
57 |
+
## How to Use
|
58 |
+
|
59 |
+
Phi-3-Small-8K-Instruct has been integrated in the development version () of `transformers`. Until the official version is released through `pip`, ensure that you are doing one of the following:
|
60 |
+
* Install tiktoken (0.6.0) ans triton (2.3.0)
|
61 |
+
|
62 |
+
* When loading the model, ensure that `trust_remote_code=True` is passed as an argument of the `from_pretrained()` function.
|
63 |
+
|
64 |
+
* Update your local `transformers` to the development version: `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers`. The previous command is an alternative to cloning and installing from the source.
|
65 |
+
|
66 |
+
The current `transformers` version can be verified with: `pip list | grep transformers`.
|
67 |
+
|
68 |
+
Phi-3-Small-8K-Instruct is also available in [Azure AI](https://ai.azure.com/explore/models?&selectedCollection=phi).
|
69 |
+
|
70 |
+
### Tokenizer
|
71 |
+
|
72 |
+
Phi-3-Small-8K-Instruct supports a vocabulary size of up to `100352` tokens.
|
73 |
+
|
74 |
+
### Chat Format
|
75 |
+
|
76 |
+
Given the nature of the training data, the Phi-3-Small-8K-Instruct model is best suited for prompts using the chat format as follows.
|
77 |
+
You can provide the prompt as a question with a generic template as follow:
|
78 |
+
```markdown
|
79 |
+
<|endoftext|><|user|>\nQuestion <|end|>\n<|assistant|>
|
80 |
+
```
|
81 |
+
For example:
|
82 |
+
```markdown
|
83 |
+
<|endoftext|><|user|>
|
84 |
+
How to explain Internet for a medieval knight?<|end|>
|
85 |
+
<|assistant|>
|
86 |
+
```
|
87 |
+
|
88 |
+
where the model generates the text after `<|assistant|>` . In case of few-shots prompt, the prompt can be formatted as the following:
|
89 |
+
|
90 |
+
```markdown
|
91 |
+
<|endoftext|><|user|>
|
92 |
+
I am going to Paris, what should I see?<|end|>
|
93 |
+
<|assistant|>
|
94 |
+
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world."<|end|>
|
95 |
+
<|user|>
|
96 |
+
What is so great about #1?<|end|>
|
97 |
+
<|assistant|>
|
98 |
+
```
|
99 |
+
|
100 |
+
### Sample inference code
|
101 |
+
|
102 |
+
This code snippets show how to get quickly started with running the model on a GPU:
|
103 |
+
|
104 |
+
```python
|
105 |
+
import torch
|
106 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
107 |
+
|
108 |
+
torch.random.manual_seed(0)
|
109 |
+
model_id = "microsoft/Phi-3-small-8k-instruct"
|
110 |
+
model = AutoModelForCausalLM.from_pretrained(
|
111 |
+
model_id,
|
112 |
+
torch_dtype="auto",
|
113 |
+
trust_remote_code=True,
|
114 |
+
)
|
115 |
+
assert torch.cuda.is_available(), "This model needs a GPU to run ..."
|
116 |
+
device = torch.cuda.current_device()
|
117 |
+
model = model.to(device)
|
118 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
119 |
+
|
120 |
+
messages = [
|
121 |
+
{"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
|
122 |
+
{"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
|
123 |
+
{"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
|
124 |
+
]
|
125 |
+
|
126 |
+
pipe = pipeline(
|
127 |
+
"text-generation",
|
128 |
+
model=model,
|
129 |
+
tokenizer=tokenizer,
|
130 |
+
device=device
|
131 |
+
)
|
132 |
+
|
133 |
+
generation_args = {
|
134 |
+
"max_new_tokens": 500,
|
135 |
+
"return_full_text": False,
|
136 |
+
"temperature": 0.0,
|
137 |
+
"do_sample": False,
|
138 |
+
}
|
139 |
+
|
140 |
+
output = pipe(messages, **generation_args)
|
141 |
+
print(output[0]['generated_text'])
|
142 |
+
```
|
143 |
+
|
144 |
+
*Some applications/frameworks might not include a BOS token (`<|endoftext|>`) at the start of the conversation. Please ensure that it is included since it provides more reliable results.*
|
145 |
+
|
146 |
+
## Responsible AI Considerations
|
147 |
+
|
148 |
+
Like other language models, the Phi series models can potentially behave in ways that are unfair, unreliable, or offensive. Some of the limiting behaviors to be aware of include:
|
149 |
+
|
150 |
+
+ Quality of Service: the Phi models are trained primarily on English text. Languages other than English will experience worse performance. English language varieties with less representation in the training data might experience worse performance than standard American English.
|
151 |
+
+ Representation of Harms & Perpetuation of Stereotypes: These models can over- or under-represent groups of people, erase representation of some groups, or reinforce demeaning or negative stereotypes. Despite safety post-training, these limitations may still be present due to differing levels of representation of different groups or prevalence of examples of negative stereotypes in training data that reflect real-world patterns and societal biases.
|
152 |
+
+ Inappropriate or Offensive Content: these models may produce other types of inappropriate or offensive content, which may make it inappropriate to deploy for sensitive contexts without additional mitigations that are specific to the use case.
|
153 |
+
+ Information Reliability: Language models can generate nonsensical content or fabricate content that might sound reasonable but is inaccurate or outdated.
|
154 |
+
+ Limited Scope for Code: Majority of Phi-3 training data is based in Python and use common packages such as "typing, math, random, collections, datetime, itertools". If the model generates Python scripts that utilize other packages or scripts in other languages, we strongly recommend users manually verify all API uses.
|
155 |
+
|
156 |
+
Developers should apply responsible AI best practices and are responsible for ensuring that a specific use case complies with relevant laws and regulations (e.g. privacy, trade, etc.). Important areas for consideration include:
|
157 |
+
|
158 |
+
+ Allocation: Models may not be suitable for scenarios that could have consequential impact on legal status or the allocation of resources or life opportunities (ex: housing, employment, credit, etc.) without further assessments and additional debiasing techniques.
|
159 |
+
+ High-Risk Scenarios: Developers should assess suitability of using models in high-risk scenarios where unfair, unreliable or offensive outputs might be extremely costly or lead to harm. This includes providing advice in sensitive or expert domains where accuracy and reliability are critical (ex: legal or health advice). Additional safeguards should be implemented at the application level according to the deployment context.
|
160 |
+
+ Misinformation: Models may produce inaccurate information. Developers should follow transparency best practices and inform end-users they are interacting with an AI system. At the application level, developers can build feedback mechanisms and pipelines to ground responses in use-case specific, contextual information, a technique known as Retrieval Augmented Generation (RAG).
|
161 |
+
+ Generation of Harmful Content: Developers should assess outputs for their context and use available safety classifiers or custom solutions appropriate for their use case.
|
162 |
+
+ Misuse: Other forms of misuse such as fraud, spam, or malware production may be possible, and developers should ensure that their applications do not violate applicable laws and regulations.
|
163 |
+
|
164 |
+
|
165 |
+
## Training
|
166 |
+
|
167 |
+
### Model
|
168 |
+
|
169 |
+
* Architecture: Phi-3 Small-8K-Instruct has 7B parameters and is a dense decoder-only Transformer model. The model is fine-tuned with Supervised fine-tuning (SFT) and Direct Preference Optimization (DPO) to ensure alignment with human preferences and safety guidlines.
|
170 |
+
* Inputs: Text. It is best suited for prompts using chat format.
|
171 |
+
* Context length: 8K tokens
|
172 |
+
* GPUs: 1024 H100-80G
|
173 |
+
* Training time: 18 days
|
174 |
+
* Training data: 4.8T tokens
|
175 |
+
* Outputs: Generated text in response to the input
|
176 |
+
* Dates: Our models were trained between February and April 2024
|
177 |
+
* Status: This is a static model trained on an offline dataset with cutoff date October 2023. Future versions of the tuned models may be released as we improve models.
|
178 |
+
* Release dates The model weight is released on May 21, 2024.
|
179 |
+
|
180 |
+
### Datasets
|
181 |
+
|
182 |
+
Our training data includes a wide variety of sources, totaling 4.8 trillion tokens (including 10% multilingual), and is a combination of
|
183 |
+
1) Publicly available documents filtered rigorously for quality, selected high-quality educational data, and code;
|
184 |
+
2) Newly created synthetic, “textbook-like” data for the purpose of teaching math, coding, common sense reasoning, general knowledge of the world (science, daily activities, theory of mind, etc.);
|
185 |
+
3) High quality chat format supervised data covering various topics to reflect human preferences on different aspects such as instruct-following, truthfulness, honesty and helpfulness.
|
186 |
+
|
187 |
+
We are focusing on the quality of data that could potentially improve the reasoning ability for the model, and we filter the publicly available documents to contain the correct level of knowledge. As an example, the result of a game in premier league in a particular day might be good training data for frontier models, but we need to remove such information to leave more model capacity for reasoning for the small size models. More details about data can be found in the [Phi-3 Technical Report](https://aka.ms/phi3-tech-report).
|
188 |
+
|
189 |
+
## Benchmarks
|
190 |
+
|
191 |
+
We report the results for Phi-3-Small-8K-Instruct on standard open-source benchmarks measuring the model's reasoning ability (both common sense reasoning and logical reasoning). We compare to Mixtral-8x7b, Gemini-Pro, Gemma 7B, Llama-3-8B-Instruct, GPT-3.5-Turbo-1106, and GPT-4-Turbo-1106.
|
192 |
+
|
193 |
+
All the reported numbers are produced with the exact same pipeline to ensure that the numbers are comparable. These numbers might differ from other published numbers due to slightly different choices in the evaluation.
|
194 |
+
|
195 |
+
As is now standard, we use few-shot prompts to evaluate the models, at temperature 0.
|
196 |
+
The prompts and number of shots are part of a Microsoft internal tool to evaluate language models, and in particular we did no optimization to the pipeline for Phi-3.
|
197 |
+
More specifically, we do not change prompts, pick different few-shot examples, change prompt format, or do any other form of optimization for the model.
|
198 |
+
|
199 |
+
The number of k–shot examples is listed per-benchmark.
|
200 |
+
|
201 |
+
|Benchmark|Phi-3-Small-8K-Instruct<br>7b|Gemma<br>7B|Mixtral<br>8x7B|Llama-3-Instruct<br>8b|GPT-3.5-Turbo<br>version 1106|Gemini<br>Pro|GPT-4-Turbo<br>version 1106 (Chat)|
|
202 |
+
|---------|-----------------------|--------|-------------|-------------------|-----------------|----------|------------------------|
|
203 |
+
|AGI Eval<br>5-shot|45.1|42.1|45.2|42.0|48.4|49.0|59.6|
|
204 |
+
|MMLU<br>5-shot|75.7|63.6|70.5|66.5|71.4|66.7|84.0|
|
205 |
+
|BigBench Hard<br>3-shot|79.1|59.6|69.7|51.5|68.3|75.6|87.7|
|
206 |
+
|ANLI<br>7-shot|58.1|48.7|55.2|57.3|58.1|64.2|71.7|
|
207 |
+
|HellaSwag<br>5-shot|77.0|49.8|70.4|71.1|78.8|76.2|88.3|
|
208 |
+
|ARC Challenge<br>10-shot|90.7|78.3|87.3|82.8|87.4|88.3|95.6|
|
209 |
+
|ARC Easy<br>10-shot|97.0|91.4|95.6|93.4|96.3|96.1|98.8|
|
210 |
+
|BoolQ<br>2-shot|84.8|66.0|76.6|80.9|79.1|86.4|91.3|
|
211 |
+
|CommonsenseQA<br>10-shot|80.0|76.2|78.1|79.0|79.6|81.8|86.7|
|
212 |
+
|MedQA<br>2-shot|65.4|49.6|62.2|60.5|63.4|58.2|83.7|
|
213 |
+
|OpenBookQA<br>10-shot|88.0|78.6|85.8|82.6|86.0|86.4|93.4|
|
214 |
+
|PIQA<br>5-shot|86.9|78.1|86.0|75.7|86.6|86.2|90.1|
|
215 |
+
|Social IQA<br>5-shot|79.2|65.5|75.9|73.9|68.3|75.4|81.7|
|
216 |
+
|TruthfulQA (MC2)<br>10-shot|70.2|52.1|60.1|63.2|67.7|72.6|85.2|
|
217 |
+
|WinoGrande<br>5-shot|81.5|55.6|62.0|65.0|68.8|72.2|86.7|
|
218 |
+
|TriviaQA<br>5-shot|58.1|72.3|82.2|67.7|85.8|80.2|73.3|
|
219 |
+
|GSM8K Chain of Thought<br>8-shot|89.6|59.8|64.7|77.4|78.1|80.4|94.2|
|
220 |
+
|HumanEval<br>0-shot|61.0|34.1|37.8|60.4|62.2|64.4|79.9|
|
221 |
+
|MBPP<br>3-shot|71.7|51.5|60.2|67.7|77.8|73.2|86.7|
|
222 |
+
|Average|75.7|61.8|69.8|69.4|74.3|75.4|85.2|
|
223 |
+
|
224 |
+
We take a closer look at different categories across 80 public benchmark datasets at the table below:
|
225 |
+
|
226 |
+
|Benchmark|Phi-3-Small-8K-Instruct<br>7b|Gemma<br>7B|Mixtral<br>8x7B|Llama-3-Instruct<br>8b|GPT-3.5-Turbo<br>version 1106|Gemini<br>Pro|GPT-4-Turbo<br>version 1106 (Chat)|
|
227 |
+
|--------|------------------------|--------|-------------|-------------------|-------------------|----------|------------------------|
|
228 |
+
|Popular aggregated benchmark|71.1|59.4|66.2|59.9|67.0|67.5|80.5|
|
229 |
+
|Reasoning|82.4|69.1|77.0|75.7|78.3|80.4|89.3|
|
230 |
+
|Language understanding|70.6|58.4|64.9|65.4|70.4|75.3|81.6|
|
231 |
+
|Code generation|60.7|45.6|52.7|56.4|70.4|66.7|76.1|
|
232 |
+
|Math|51.6|35.8|40.3|41.1|52.8|50.9|67.1|
|
233 |
+
|Factual knowledge|38.6|46.7|58.6|43.1|63.4|54.6|45.9|
|
234 |
+
|Multilingual|62.5|63.2|63.4|65.0|69.1|76.5|82.0|
|
235 |
+
|Robustness|72.9|38.4|51.0|64.5|69.3|69.7|84.6|
|
236 |
+
|
237 |
+
|
238 |
+
## Software
|
239 |
+
|
240 |
+
* [PyTorch](https://github.com/pytorch/pytorch)
|
241 |
+
* [DeepSpeed](https://github.com/microsoft/DeepSpeed)
|
242 |
+
* [Transformers](https://github.com/huggingface/transformers)
|
243 |
+
* [Flash-Attention](https://github.com/HazyResearch/flash-attention)
|
244 |
+
* [Tiktoken](https://github.com/openai/tiktoken)
|
245 |
+
* [Triton](https://github.com/openai/triton)
|
246 |
+
|
247 |
+
## Hardware
|
248 |
+
Note that by default, the Phi-3-Small model uses flash attention, which requires certain types of GPU hardware to run. We have tested on the following GPU types:
|
249 |
+
* NVIDIA A100
|
250 |
+
* NVIDIA A6000
|
251 |
+
* NVIDIA H100
|
252 |
+
|
253 |
+
If you want to run the model on:
|
254 |
+
+ Optimized inference on GPU, CPU, and Mobile: use the **ONNX** models [8K](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct-onnx)
|
255 |
+
|
256 |
+
|
257 |
+
## Cross Platform Support
|
258 |
+
|
259 |
+
ONNX runtime ecosystem now supports Phi3 small models across platforms and hardware.
|
260 |
+
Optimized phi-3 models are also published here in ONNX format, to run with ONNX Runtime on CPU and GPU across devices, including server platforms, Windows, Linux and Mac desktops, and mobile CPUs, with the precision best suited to each of these targets. DirectML GPU acceleration is supported for Windows desktops GPUs (AMD, Intel, and NVIDIA).
|
261 |
+
Along with DML, ONNX Runtime provides cross platform support for Phi3 Small across a range of devices CPU, GPU, and mobile.
|
262 |
+
Here are some of the optimized configurations we have added:
|
263 |
+
|
264 |
+
1. ONNX models for int4 DML: Quantized to int4 via AWQ
|
265 |
+
2. ONNX model for fp16 CUDA
|
266 |
+
3. ONNX model for int4 CUDA: Quantized to int4 via RTN
|
267 |
+
4. ONNX model for int4 CPU and Mobile: Quantized to int4 via RTN
|
268 |
+
|
269 |
+
## License
|
270 |
+
|
271 |
+
The model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-small-8k/resolve/main/LICENSE).
|
272 |
+
|
273 |
+
## Trademarks
|
274 |
+
|
275 |
+
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
|
SECURITY.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
|
2 |
+
|
3 |
+
## Security
|
4 |
+
|
5 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
|
6 |
+
|
7 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
|
8 |
+
|
9 |
+
## Reporting Security Issues
|
10 |
+
|
11 |
+
**Please do not report security vulnerabilities through public GitHub issues.**
|
12 |
+
|
13 |
+
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
|
14 |
+
|
15 |
+
If you prefer to submit without logging in, send email to [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
|
16 |
+
|
17 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
18 |
+
|
19 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
20 |
+
|
21 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
22 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
23 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
24 |
+
* Any special configuration required to reproduce the issue
|
25 |
+
* Step-by-step instructions to reproduce the issue
|
26 |
+
* Proof-of-concept or exploit code (if possible)
|
27 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
28 |
+
|
29 |
+
This information will help us triage your report more quickly.
|
30 |
+
|
31 |
+
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
|
32 |
+
|
33 |
+
## Preferred Languages
|
34 |
+
|
35 |
+
We prefer all communications to be in English.
|
36 |
+
|
37 |
+
## Policy
|
38 |
+
|
39 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
|
40 |
+
|
41 |
+
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
SUPPORT.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO: The maintainer of this repo has not yet edited this file
|
2 |
+
|
3 |
+
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
|
4 |
+
|
5 |
+
- **No CSS support:** Fill out this template with information about how to file issues and get help.
|
6 |
+
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
|
7 |
+
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
|
8 |
+
|
9 |
+
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
|
10 |
+
|
11 |
+
# Support
|
12 |
+
|
13 |
+
## How to file issues and get help
|
14 |
+
|
15 |
+
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
16 |
+
issues before filing new issues to avoid duplicates. For new issues, file your bug or
|
17 |
+
feature request as a new Issue.
|
18 |
+
|
19 |
+
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
|
20 |
+
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
|
21 |
+
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
|
22 |
+
|
23 |
+
## Microsoft Support Policy
|
24 |
+
|
25 |
+
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
|
cl100k_base.tiktoken
ADDED
The diff for this file is too large to render.
See raw diff
|
|
config.json
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "Phi-3-small-8k-instruct",
|
3 |
+
"architectures": [
|
4 |
+
"Phi3SmallForCausalLM"
|
5 |
+
],
|
6 |
+
"attention_dropout_prob": 0.0,
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_phi3_small.Phi3SmallConfig",
|
9 |
+
"AutoModelForCausalLM": "modeling_phi3_small.Phi3SmallForCausalLM",
|
10 |
+
"AutoTokenizer": "tokenization_phi3_small.Phi3SmallTokenizer"
|
11 |
+
},
|
12 |
+
"blocksparse_block_size": 64,
|
13 |
+
"blocksparse_homo_head_pattern": false,
|
14 |
+
"blocksparse_num_local_blocks": 16,
|
15 |
+
"blocksparse_triton_kernel_block_size": 64,
|
16 |
+
"blocksparse_vert_stride": 8,
|
17 |
+
"bos_token_id": 100257,
|
18 |
+
"dense_attention_every_n_layers": 2,
|
19 |
+
"embedding_dropout_prob": 0.1,
|
20 |
+
"eos_token_id": 100257,
|
21 |
+
"ff_dim_multiplier": null,
|
22 |
+
"ff_intermediate_size": 14336,
|
23 |
+
"ffn_dropout_prob": 0.1,
|
24 |
+
"gegelu_limit": 20.0,
|
25 |
+
"gegelu_pad_to_256": true,
|
26 |
+
"hidden_act": "gegelu",
|
27 |
+
"hidden_size": 4096,
|
28 |
+
"initializer_range": 0.02,
|
29 |
+
"layer_norm_epsilon": 1e-05,
|
30 |
+
"max_position_embeddings": 8192,
|
31 |
+
"model_type": "phi3small",
|
32 |
+
"mup_attn_multiplier": 1.0,
|
33 |
+
"mup_embedding_multiplier": 10.0,
|
34 |
+
"mup_use_scaling": true,
|
35 |
+
"mup_width_multiplier": 8.0,
|
36 |
+
"num_attention_heads": 32,
|
37 |
+
"num_hidden_layers": 32,
|
38 |
+
"num_key_value_heads": 8,
|
39 |
+
"pad_sequence_to_multiple_of_64": true,
|
40 |
+
"reorder_and_upcast_attn": false,
|
41 |
+
"rope_embedding_base": 1000000,
|
42 |
+
"rope_position_scale": 1.0,
|
43 |
+
"torch_dtype": "bfloat16",
|
44 |
+
"transformers_version": "4.38.1",
|
45 |
+
"use_cache": true,
|
46 |
+
"vocab_size": 100352
|
47 |
+
}
|
configuration_phi3_small.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
from typing import Any, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
from functools import cached_property
|
22 |
+
|
23 |
+
""" Phi3Small model configuration """
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
def next_mult(x, y):
|
28 |
+
return (x + y - 1) // y * y
|
29 |
+
|
30 |
+
class Phi3SmallConfig(PretrainedConfig):
|
31 |
+
"""
|
32 |
+
This is the configuration class to store the configuration of a `Phi3Small` model. It is used to
|
33 |
+
instantiate a Phi-3-small model according to the specified arguments, defining the model architecture.
|
34 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the Phi-3-small
|
35 |
+
[phi3](https://arxiv.org/pdf/2404.14219) architecture.
|
36 |
+
|
37 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
38 |
+
documentation from [`PretrainedConfig`] for more information.
|
39 |
+
|
40 |
+
|
41 |
+
Args:
|
42 |
+
vocab_size (`int`, *optional*, defaults to 100352):
|
43 |
+
Vocabulary size of the Phi3Small model. Defines the number of different tokens that can be represented by the
|
44 |
+
`inputs_ids` passed when calling `Phi3Small`.
|
45 |
+
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
46 |
+
The maximum sequence length that this model might safely be used with.
|
47 |
+
rope_embedding_base (`float`, *optional*, defaults to 10^6):
|
48 |
+
The base value for the RoPE (Relative Position Encoding) embedding.
|
49 |
+
rope_position_scale (`float`, *optional*, defaults to 1.0):
|
50 |
+
The scale factor for the RoPE position encoding.
|
51 |
+
rope_scaling (`Optional[Dict[str, Union[float, List[float], int]]]`, *optional*, defaults to None):
|
52 |
+
The scaling configuration used for LongRoPE.
|
53 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
54 |
+
The size of the hidden layers in the model.
|
55 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
56 |
+
The number of layers in the model.
|
57 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
58 |
+
The number of query heads in the model.
|
59 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
60 |
+
The number of key-value heads in the model.
|
61 |
+
hidden_act (`str`, *optional*, defaults to "gegelu"):
|
62 |
+
The activation function used in the model.
|
63 |
+
gegelu_limit (`float`, *optional*, defaults to 20.0):
|
64 |
+
The limit value for the GELU activation function (for numerical stability).
|
65 |
+
gegelu_pad_to_256 (`bool`, *optional*, defaults to True):
|
66 |
+
Whether to pad the intermediate size to a multiple of 256 (for faster matmul ops).
|
67 |
+
ff_dim_multiplier (`Optional[int]`, *optional*, defaults to None):
|
68 |
+
The dimension multiplier for the feed-forward layers.
|
69 |
+
ff_intermediate_size (`Optional[int]`, *optional*, defaults to 14336):
|
70 |
+
The intermediate size for the feed-forward layers.
|
71 |
+
One of `ff_dim_multiplier` or `ff_intermediate_size` must be specified.
|
72 |
+
blocksparse_homo_head_pattern (`bool`, *optional*, defaults to False):
|
73 |
+
Whether to use a homogeneous head pattern for block-sparse attention.
|
74 |
+
blocksparse_block_size (`int`, *optional*, defaults to 64):
|
75 |
+
The block size for block-sparse attention.
|
76 |
+
blocksparse_num_local_blocks (`int`, *optional*, defaults to 16):
|
77 |
+
The number of local blocks for block-sparse attention.
|
78 |
+
The local window used in blocksparse equals `blocksparse_num_local_blocks * blocksparse_block_size`
|
79 |
+
blocksparse_vert_stride (`int`, *optional*, defaults to 8):
|
80 |
+
The vertical stride for block-sparse attention.
|
81 |
+
blocksparse_triton_kernel_block_size (`int`, *optional*, defaults to 64):
|
82 |
+
The kernel block size for block-sparse attention.
|
83 |
+
dense_attention_every_n_layers (`Optional[int]`, *optional*, defaults to 2):
|
84 |
+
The frequency of all dense attention layers in the model
|
85 |
+
embedding_dropout_prob (`float`, *optional*, defaults to 0.1):
|
86 |
+
The dropout probability for the embedding layer.
|
87 |
+
attention_dropout_prob (`float`, *optional*, defaults to 0.0):
|
88 |
+
The dropout probability for the attention layers.
|
89 |
+
ffn_dropout_prob (`float`, *optional*, defaults to 0.1):
|
90 |
+
The dropout probability for the feed-forward layers.
|
91 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
92 |
+
The epsilon value for layer normalization.
|
93 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
94 |
+
The range for weight initialization.
|
95 |
+
mup_use_scaling (`bool`, *optional*, defaults to True):
|
96 |
+
Whether to use scaling for MuP parameters (see: https://arxiv.org/abs/2203.03466).
|
97 |
+
mup_width_multiplier (`bool`, *optional*, defaults to 8.0):
|
98 |
+
The width multiplier for MuP.
|
99 |
+
mup_embedding_multiplier (`bool`, *optional*, defaults to 10.0):
|
100 |
+
The embedding multiplier for MuP.
|
101 |
+
mup_attn_multiplier (`bool`, *optional*, defaults to 1.0):
|
102 |
+
The attention multiplier for MuP.
|
103 |
+
use_cache (`bool`, *optional*, defaults to True):
|
104 |
+
Whether to use cache for the model.
|
105 |
+
bos_token_id (`int`, *optional*, defaults to 100257):
|
106 |
+
The token ID for the beginning of sentence.
|
107 |
+
eos_token_id (`int`, *optional*, defaults to 100257):
|
108 |
+
The token ID for the end of sentence.
|
109 |
+
reorder_and_upcast_attn (`bool`, *optional*, defaults to False):
|
110 |
+
Whether to reorder and upcast attention.
|
111 |
+
pad_sequence_to_multiple_of_64 (`bool`, *optional*, defaults to True):
|
112 |
+
Whether to pad the sequence length to a multiple of 64.
|
113 |
+
**kwargs:
|
114 |
+
Additional keyword arguments.
|
115 |
+
|
116 |
+
Example:
|
117 |
+
|
118 |
+
```python
|
119 |
+
>>> from transformers import Phi3SmallConfig, Phi3SmallModel
|
120 |
+
|
121 |
+
>>> # Initializing a Phi3Small configuration
|
122 |
+
>>> configuration = Phi3SmallConfig()
|
123 |
+
|
124 |
+
>>> # Initializing a model (with random weights) from the configuration
|
125 |
+
>>> model = Phi3SmallModel(configuration)
|
126 |
+
|
127 |
+
>>> # Accessing the model configuration
|
128 |
+
>>> configuration = model.config
|
129 |
+
```
|
130 |
+
"""
|
131 |
+
|
132 |
+
model_type = "phi3small"
|
133 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
134 |
+
|
135 |
+
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
# General information about the model
|
139 |
+
vocab_size: int =100352,
|
140 |
+
max_position_embeddings: int = 8192,
|
141 |
+
# RoPE Related Parameters
|
142 |
+
rope_embedding_base: float = 10**6,
|
143 |
+
rope_position_scale: float = 1.0,
|
144 |
+
rope_scaling: Optional[Dict[str, Union[float, List[float], int]]] = None,
|
145 |
+
# General Model Parameters
|
146 |
+
hidden_size: int = 4096,
|
147 |
+
num_hidden_layers: int = 32,
|
148 |
+
# KV Shared Attention Configurations
|
149 |
+
num_attention_heads: int = 32,
|
150 |
+
num_key_value_heads: int = 8,
|
151 |
+
# GEGELU Related Parameters
|
152 |
+
hidden_act: str = "gegelu",
|
153 |
+
gegelu_limit: float = 20.0,
|
154 |
+
gegelu_pad_to_256: bool = True,
|
155 |
+
ff_dim_multiplier: Optional[int] = None,
|
156 |
+
ff_intermediate_size: Optional[int] = 14336,
|
157 |
+
# Block Sparse Attention Parameters
|
158 |
+
blocksparse_homo_head_pattern: bool = False,
|
159 |
+
blocksparse_block_size: int = 64,
|
160 |
+
blocksparse_num_local_blocks: int = 16,
|
161 |
+
blocksparse_vert_stride: int = 8,
|
162 |
+
blocksparse_triton_kernel_block_size: int = 64,
|
163 |
+
# Frequency of block-sparsity
|
164 |
+
dense_attention_every_n_layers: Optional[int] = 2,
|
165 |
+
# Reegularization parameters
|
166 |
+
embedding_dropout_prob: float =0.1,
|
167 |
+
attention_dropout_prob: float = 0.0,
|
168 |
+
ffn_dropout_prob: float = 0.1,
|
169 |
+
layer_norm_epsilon=1e-5,
|
170 |
+
initializer_range=0.02,
|
171 |
+
# MuP parameters
|
172 |
+
mup_use_scaling: bool = True,
|
173 |
+
mup_width_multiplier: bool = 8.0,
|
174 |
+
mup_embedding_multiplier: bool = 10.0,
|
175 |
+
mup_attn_multiplier: bool =1.0,
|
176 |
+
use_cache=True,
|
177 |
+
# The model does not have a bos token id
|
178 |
+
# However, in order for some of the downstream libraries to not break
|
179 |
+
# we set this to be the same as the eos_token_id
|
180 |
+
bos_token_id: int = 100257,
|
181 |
+
eos_token_id: int = 100257,
|
182 |
+
reorder_and_upcast_attn=False,
|
183 |
+
# Configuration to pad sequence length to a multiple of 64
|
184 |
+
pad_sequence_to_multiple_of_64: bool = True,
|
185 |
+
**kwargs,
|
186 |
+
):
|
187 |
+
self.vocab_size = vocab_size
|
188 |
+
self.max_position_embeddings = max_position_embeddings
|
189 |
+
self.rope_embedding_base = rope_embedding_base
|
190 |
+
self.rope_position_scale = rope_position_scale
|
191 |
+
self.rope_scaling = rope_scaling
|
192 |
+
self.hidden_size = hidden_size
|
193 |
+
# QK Shared Attention
|
194 |
+
self.num_hidden_layers = num_hidden_layers
|
195 |
+
self.num_attention_heads = num_attention_heads
|
196 |
+
self.num_key_value_heads = num_key_value_heads
|
197 |
+
# Block Sparse Attention Pattern
|
198 |
+
self.blocksparse_homo_head_pattern = blocksparse_homo_head_pattern
|
199 |
+
self.blocksparse_block_size = blocksparse_block_size
|
200 |
+
self.blocksparse_num_local_blocks = blocksparse_num_local_blocks
|
201 |
+
self.blocksparse_vert_stride = blocksparse_vert_stride
|
202 |
+
self.blocksparse_triton_kernel_block_size = blocksparse_triton_kernel_block_size
|
203 |
+
# Frequency of block sparsity
|
204 |
+
self.dense_attention_every_n_layers = dense_attention_every_n_layers
|
205 |
+
# Activation function
|
206 |
+
self.hidden_act = hidden_act
|
207 |
+
self.gegelu_limit = gegelu_limit
|
208 |
+
self.gegelu_pad_to_256 = gegelu_pad_to_256
|
209 |
+
self.ff_dim_multiplier = ff_dim_multiplier
|
210 |
+
self.ff_intermediate_size = ff_intermediate_size
|
211 |
+
if self.ff_dim_multiplier is None and self.ff_intermediate_size is None:
|
212 |
+
raise ValueError(f"Cannot have both {self.ff_dim_multiplier} and {self.ff_intermediate_size} as None")
|
213 |
+
if self.ff_dim_multiplier is not None and self.ff_intermediate_size is not None:
|
214 |
+
raise ValueError(f"Cannot specify both {self.ff_dim_multiplier} and {self.ff_intermediate_size}.")
|
215 |
+
# General regularization
|
216 |
+
self.embedding_dropout_prob = embedding_dropout_prob
|
217 |
+
self.attention_dropout_prob = attention_dropout_prob
|
218 |
+
self.ffn_dropout_prob = ffn_dropout_prob
|
219 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
220 |
+
self.initializer_range = initializer_range
|
221 |
+
# MuP parameters
|
222 |
+
self.mup_use_scaling = mup_use_scaling
|
223 |
+
self.mup_width_multiplier = mup_width_multiplier
|
224 |
+
self.mup_embedding_multiplier = mup_embedding_multiplier
|
225 |
+
self.mup_attn_multiplier = mup_attn_multiplier
|
226 |
+
self.use_cache = use_cache
|
227 |
+
|
228 |
+
self.reorder_and_upcast_attn = reorder_and_upcast_attn
|
229 |
+
self.pad_sequence_to_multiple_of_64 = pad_sequence_to_multiple_of_64
|
230 |
+
|
231 |
+
self.bos_token_id = bos_token_id
|
232 |
+
self.eos_token_id = eos_token_id
|
233 |
+
|
234 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
235 |
+
|
236 |
+
@cached_property
|
237 |
+
def dummy_token_indices(self) -> List[int]:
|
238 |
+
# Importing here to avoid circular imports
|
239 |
+
from .tokenization_phi3_small import Phi3SmallTokenizer
|
240 |
+
tokenizer = Phi3SmallTokenizer()
|
241 |
+
return tokenizer.dummy_token_indices
|
242 |
+
|
243 |
+
@property
|
244 |
+
def intermediate_size(self) -> int:
|
245 |
+
if self.ff_intermediate_size is not None:
|
246 |
+
return self.ff_intermediate_size
|
247 |
+
intermediate_size = (self.ff_dim_multiplier) * (self.hidden_size // 3) * 2
|
248 |
+
if self.gegelu_pad_to_256:
|
249 |
+
intermediate_size = next_mult(intermediate_size, 256)
|
250 |
+
return intermediate_size
|
generation_config.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 100257,
|
4 |
+
"eos_token_id": [
|
5 |
+
100257,
|
6 |
+
100266
|
7 |
+
],
|
8 |
+
"transformers_version": "4.38.1"
|
9 |
+
}
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0a8435e8fd0cc2a302f057814bb7e2650f16a4812a9b34339e3769e213276797
|
3 |
+
size 4832943104
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0be58e1371e8630fff0f8655d6be99a2dfc6ccfb4e00bc4fa85e831b8042eac6
|
3 |
+
size 4799608224
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:77aa243e7aa0a19eb37eb8dabc6f30de9a779c606cb476e5ac432d742fe7e917
|
3 |
+
size 4799608240
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4cb5772a577868e7e794bed074c19b0d5284a5f9a0a89b537c33623873940f3a
|
3 |
+
size 352437304
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 14784548864
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
7 |
+
"model.final_layernorm.bias": "model-00004-of-00004.safetensors",
|
8 |
+
"model.final_layernorm.weight": "model-00004-of-00004.safetensors",
|
9 |
+
"model.layers.0.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
10 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
11 |
+
"model.layers.0.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
12 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
13 |
+
"model.layers.0.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
14 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
15 |
+
"model.layers.0.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
16 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
17 |
+
"model.layers.0.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
18 |
+
"model.layers.0.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
19 |
+
"model.layers.0.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
20 |
+
"model.layers.0.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
21 |
+
"model.layers.0.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
22 |
+
"model.layers.1.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
23 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
24 |
+
"model.layers.1.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
25 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
26 |
+
"model.layers.1.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
27 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
28 |
+
"model.layers.1.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
29 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
30 |
+
"model.layers.1.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
31 |
+
"model.layers.1.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
32 |
+
"model.layers.1.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
33 |
+
"model.layers.1.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
34 |
+
"model.layers.1.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
35 |
+
"model.layers.10.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
36 |
+
"model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
37 |
+
"model.layers.10.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
38 |
+
"model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
39 |
+
"model.layers.10.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
40 |
+
"model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
41 |
+
"model.layers.10.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
42 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
43 |
+
"model.layers.10.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
44 |
+
"model.layers.10.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
45 |
+
"model.layers.10.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
46 |
+
"model.layers.10.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
47 |
+
"model.layers.10.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
48 |
+
"model.layers.11.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
49 |
+
"model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
50 |
+
"model.layers.11.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
51 |
+
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
52 |
+
"model.layers.11.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
53 |
+
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
54 |
+
"model.layers.11.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
55 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
56 |
+
"model.layers.11.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
57 |
+
"model.layers.11.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
58 |
+
"model.layers.11.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
59 |
+
"model.layers.11.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
60 |
+
"model.layers.11.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
61 |
+
"model.layers.12.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
62 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
63 |
+
"model.layers.12.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
64 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
65 |
+
"model.layers.12.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
66 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
67 |
+
"model.layers.12.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
68 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
69 |
+
"model.layers.12.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
70 |
+
"model.layers.12.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
71 |
+
"model.layers.12.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
72 |
+
"model.layers.12.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
73 |
+
"model.layers.12.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
74 |
+
"model.layers.13.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
75 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
76 |
+
"model.layers.13.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
77 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
78 |
+
"model.layers.13.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
79 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
80 |
+
"model.layers.13.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
81 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
82 |
+
"model.layers.13.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
83 |
+
"model.layers.13.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
84 |
+
"model.layers.13.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
85 |
+
"model.layers.13.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
86 |
+
"model.layers.13.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
87 |
+
"model.layers.14.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
88 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
89 |
+
"model.layers.14.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
90 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
91 |
+
"model.layers.14.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
92 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
93 |
+
"model.layers.14.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
94 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
95 |
+
"model.layers.14.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
96 |
+
"model.layers.14.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
97 |
+
"model.layers.14.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
98 |
+
"model.layers.14.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
99 |
+
"model.layers.14.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
100 |
+
"model.layers.15.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
101 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
102 |
+
"model.layers.15.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
103 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
104 |
+
"model.layers.15.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
105 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
106 |
+
"model.layers.15.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
107 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
108 |
+
"model.layers.15.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
109 |
+
"model.layers.15.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
110 |
+
"model.layers.15.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
111 |
+
"model.layers.15.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
112 |
+
"model.layers.15.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
113 |
+
"model.layers.16.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
114 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
115 |
+
"model.layers.16.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
116 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
117 |
+
"model.layers.16.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
118 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
119 |
+
"model.layers.16.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
120 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
121 |
+
"model.layers.16.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
122 |
+
"model.layers.16.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
123 |
+
"model.layers.16.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
124 |
+
"model.layers.16.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
125 |
+
"model.layers.16.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
126 |
+
"model.layers.17.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
127 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
128 |
+
"model.layers.17.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
129 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
130 |
+
"model.layers.17.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
131 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
132 |
+
"model.layers.17.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
133 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
134 |
+
"model.layers.17.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
135 |
+
"model.layers.17.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
136 |
+
"model.layers.17.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
137 |
+
"model.layers.17.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
138 |
+
"model.layers.17.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
139 |
+
"model.layers.18.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
140 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
141 |
+
"model.layers.18.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
142 |
+
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
143 |
+
"model.layers.18.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
144 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
145 |
+
"model.layers.18.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
146 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
147 |
+
"model.layers.18.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
148 |
+
"model.layers.18.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
149 |
+
"model.layers.18.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
150 |
+
"model.layers.18.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
151 |
+
"model.layers.18.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
152 |
+
"model.layers.19.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
153 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
154 |
+
"model.layers.19.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
155 |
+
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
156 |
+
"model.layers.19.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
157 |
+
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
158 |
+
"model.layers.19.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
159 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
160 |
+
"model.layers.19.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
161 |
+
"model.layers.19.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
162 |
+
"model.layers.19.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
163 |
+
"model.layers.19.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
164 |
+
"model.layers.19.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
165 |
+
"model.layers.2.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
166 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
167 |
+
"model.layers.2.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
168 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
169 |
+
"model.layers.2.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
170 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
171 |
+
"model.layers.2.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
172 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
173 |
+
"model.layers.2.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
174 |
+
"model.layers.2.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
175 |
+
"model.layers.2.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
176 |
+
"model.layers.2.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
177 |
+
"model.layers.2.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
178 |
+
"model.layers.20.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
179 |
+
"model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
180 |
+
"model.layers.20.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
181 |
+
"model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
182 |
+
"model.layers.20.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
183 |
+
"model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
184 |
+
"model.layers.20.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
185 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
186 |
+
"model.layers.20.self_attn.dense.bias": "model-00002-of-00004.safetensors",
|
187 |
+
"model.layers.20.self_attn.dense.weight": "model-00002-of-00004.safetensors",
|
188 |
+
"model.layers.20.self_attn.query_key_value.bias": "model-00002-of-00004.safetensors",
|
189 |
+
"model.layers.20.self_attn.query_key_value.weight": "model-00002-of-00004.safetensors",
|
190 |
+
"model.layers.20.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
191 |
+
"model.layers.21.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
192 |
+
"model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
193 |
+
"model.layers.21.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
194 |
+
"model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
195 |
+
"model.layers.21.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
196 |
+
"model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
197 |
+
"model.layers.21.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
198 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
199 |
+
"model.layers.21.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
200 |
+
"model.layers.21.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
201 |
+
"model.layers.21.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
202 |
+
"model.layers.21.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
203 |
+
"model.layers.21.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
204 |
+
"model.layers.22.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
205 |
+
"model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
206 |
+
"model.layers.22.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
207 |
+
"model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
208 |
+
"model.layers.22.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
209 |
+
"model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
210 |
+
"model.layers.22.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
211 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
212 |
+
"model.layers.22.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
213 |
+
"model.layers.22.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
214 |
+
"model.layers.22.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
215 |
+
"model.layers.22.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
216 |
+
"model.layers.22.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
217 |
+
"model.layers.23.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
218 |
+
"model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
219 |
+
"model.layers.23.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
220 |
+
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
221 |
+
"model.layers.23.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
222 |
+
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
223 |
+
"model.layers.23.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
224 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
225 |
+
"model.layers.23.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
226 |
+
"model.layers.23.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
227 |
+
"model.layers.23.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
228 |
+
"model.layers.23.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
229 |
+
"model.layers.23.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
230 |
+
"model.layers.24.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
231 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
232 |
+
"model.layers.24.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
233 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
234 |
+
"model.layers.24.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
235 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
236 |
+
"model.layers.24.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
237 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
238 |
+
"model.layers.24.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
239 |
+
"model.layers.24.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
240 |
+
"model.layers.24.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
241 |
+
"model.layers.24.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
242 |
+
"model.layers.24.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
243 |
+
"model.layers.25.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
244 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
245 |
+
"model.layers.25.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
246 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
247 |
+
"model.layers.25.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
248 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
249 |
+
"model.layers.25.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
250 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
251 |
+
"model.layers.25.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
252 |
+
"model.layers.25.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
253 |
+
"model.layers.25.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
254 |
+
"model.layers.25.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
255 |
+
"model.layers.25.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
256 |
+
"model.layers.26.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
257 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
258 |
+
"model.layers.26.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
259 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
260 |
+
"model.layers.26.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
261 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
262 |
+
"model.layers.26.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
263 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
264 |
+
"model.layers.26.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
265 |
+
"model.layers.26.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
266 |
+
"model.layers.26.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
267 |
+
"model.layers.26.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
268 |
+
"model.layers.26.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
269 |
+
"model.layers.27.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
270 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
271 |
+
"model.layers.27.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
272 |
+
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
273 |
+
"model.layers.27.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
274 |
+
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
275 |
+
"model.layers.27.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
276 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
277 |
+
"model.layers.27.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
278 |
+
"model.layers.27.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
279 |
+
"model.layers.27.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
280 |
+
"model.layers.27.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
281 |
+
"model.layers.27.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
282 |
+
"model.layers.28.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
283 |
+
"model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
284 |
+
"model.layers.28.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
285 |
+
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
286 |
+
"model.layers.28.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
287 |
+
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
288 |
+
"model.layers.28.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
289 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
290 |
+
"model.layers.28.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
291 |
+
"model.layers.28.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
292 |
+
"model.layers.28.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
293 |
+
"model.layers.28.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
294 |
+
"model.layers.28.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
295 |
+
"model.layers.29.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
296 |
+
"model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
297 |
+
"model.layers.29.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
298 |
+
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
299 |
+
"model.layers.29.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
300 |
+
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
301 |
+
"model.layers.29.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
302 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
303 |
+
"model.layers.29.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
304 |
+
"model.layers.29.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
305 |
+
"model.layers.29.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
306 |
+
"model.layers.29.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
307 |
+
"model.layers.29.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
308 |
+
"model.layers.3.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
309 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
310 |
+
"model.layers.3.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
311 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
312 |
+
"model.layers.3.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
313 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
314 |
+
"model.layers.3.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
315 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
316 |
+
"model.layers.3.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
317 |
+
"model.layers.3.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
318 |
+
"model.layers.3.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
319 |
+
"model.layers.3.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
320 |
+
"model.layers.3.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
321 |
+
"model.layers.30.input_layernorm.bias": "model-00003-of-00004.safetensors",
|
322 |
+
"model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
323 |
+
"model.layers.30.mlp.down_proj.bias": "model-00003-of-00004.safetensors",
|
324 |
+
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
325 |
+
"model.layers.30.mlp.up_proj.bias": "model-00003-of-00004.safetensors",
|
326 |
+
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
327 |
+
"model.layers.30.post_attention_layernorm.bias": "model-00003-of-00004.safetensors",
|
328 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
329 |
+
"model.layers.30.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
330 |
+
"model.layers.30.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
331 |
+
"model.layers.30.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
332 |
+
"model.layers.30.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
333 |
+
"model.layers.30.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
334 |
+
"model.layers.31.input_layernorm.bias": "model-00004-of-00004.safetensors",
|
335 |
+
"model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
336 |
+
"model.layers.31.mlp.down_proj.bias": "model-00004-of-00004.safetensors",
|
337 |
+
"model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
338 |
+
"model.layers.31.mlp.up_proj.bias": "model-00004-of-00004.safetensors",
|
339 |
+
"model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
340 |
+
"model.layers.31.post_attention_layernorm.bias": "model-00004-of-00004.safetensors",
|
341 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
342 |
+
"model.layers.31.self_attn.dense.bias": "model-00003-of-00004.safetensors",
|
343 |
+
"model.layers.31.self_attn.dense.weight": "model-00003-of-00004.safetensors",
|
344 |
+
"model.layers.31.self_attn.query_key_value.bias": "model-00003-of-00004.safetensors",
|
345 |
+
"model.layers.31.self_attn.query_key_value.weight": "model-00003-of-00004.safetensors",
|
346 |
+
"model.layers.31.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
347 |
+
"model.layers.4.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
348 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
349 |
+
"model.layers.4.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
350 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
351 |
+
"model.layers.4.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
352 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
353 |
+
"model.layers.4.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
354 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
355 |
+
"model.layers.4.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
356 |
+
"model.layers.4.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
357 |
+
"model.layers.4.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
358 |
+
"model.layers.4.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
359 |
+
"model.layers.4.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
360 |
+
"model.layers.5.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
361 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
362 |
+
"model.layers.5.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
363 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
364 |
+
"model.layers.5.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
365 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
366 |
+
"model.layers.5.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
367 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
368 |
+
"model.layers.5.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
369 |
+
"model.layers.5.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
370 |
+
"model.layers.5.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
371 |
+
"model.layers.5.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
372 |
+
"model.layers.5.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
373 |
+
"model.layers.6.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
374 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
375 |
+
"model.layers.6.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
376 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
377 |
+
"model.layers.6.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
378 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
379 |
+
"model.layers.6.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
380 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
381 |
+
"model.layers.6.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
382 |
+
"model.layers.6.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
383 |
+
"model.layers.6.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
384 |
+
"model.layers.6.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
385 |
+
"model.layers.6.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
386 |
+
"model.layers.7.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
387 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
388 |
+
"model.layers.7.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
389 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
390 |
+
"model.layers.7.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
391 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
392 |
+
"model.layers.7.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
393 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
394 |
+
"model.layers.7.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
395 |
+
"model.layers.7.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
396 |
+
"model.layers.7.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
397 |
+
"model.layers.7.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
398 |
+
"model.layers.7.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
399 |
+
"model.layers.8.input_layernorm.bias": "model-00001-of-00004.safetensors",
|
400 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
401 |
+
"model.layers.8.mlp.down_proj.bias": "model-00001-of-00004.safetensors",
|
402 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
403 |
+
"model.layers.8.mlp.up_proj.bias": "model-00001-of-00004.safetensors",
|
404 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
405 |
+
"model.layers.8.post_attention_layernorm.bias": "model-00001-of-00004.safetensors",
|
406 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
407 |
+
"model.layers.8.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
408 |
+
"model.layers.8.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
409 |
+
"model.layers.8.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
410 |
+
"model.layers.8.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
411 |
+
"model.layers.8.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
412 |
+
"model.layers.9.input_layernorm.bias": "model-00002-of-00004.safetensors",
|
413 |
+
"model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
414 |
+
"model.layers.9.mlp.down_proj.bias": "model-00002-of-00004.safetensors",
|
415 |
+
"model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
416 |
+
"model.layers.9.mlp.up_proj.bias": "model-00002-of-00004.safetensors",
|
417 |
+
"model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
418 |
+
"model.layers.9.post_attention_layernorm.bias": "model-00002-of-00004.safetensors",
|
419 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
420 |
+
"model.layers.9.self_attn.dense.bias": "model-00001-of-00004.safetensors",
|
421 |
+
"model.layers.9.self_attn.dense.weight": "model-00001-of-00004.safetensors",
|
422 |
+
"model.layers.9.self_attn.query_key_value.bias": "model-00001-of-00004.safetensors",
|
423 |
+
"model.layers.9.self_attn.query_key_value.weight": "model-00001-of-00004.safetensors",
|
424 |
+
"model.layers.9.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors"
|
425 |
+
}
|
426 |
+
}
|
modeling_phi3_small.py
ADDED
@@ -0,0 +1,1140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any, Dict, Optional, List, Tuple, Union
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
from transformers.modeling_outputs import SequenceClassifierOutputWithPast, CausalLMOutputWithPast, BaseModelOutputWithPast
|
11 |
+
from transformers.modeling_utils import PreTrainedModel
|
12 |
+
from transformers.utils import logging
|
13 |
+
|
14 |
+
from transformers.cache_utils import Cache, DynamicCache
|
15 |
+
|
16 |
+
from .triton_flash_blocksparse_attn import BlockSparseParams
|
17 |
+
from .triton_blocksparse_attention_layer import BlockSparseAttentionLayer
|
18 |
+
from .positional_embedding import RotaryEmbedding
|
19 |
+
|
20 |
+
from .configuration_phi3_small import Phi3SmallConfig
|
21 |
+
|
22 |
+
# Flash Attention Related Imports
|
23 |
+
is_flash_attention_available = False
|
24 |
+
try:
|
25 |
+
import flash_attn
|
26 |
+
if int(flash_attn.__version__.split('.')[0]) < 2:
|
27 |
+
from flash_attn.flash_attn_interface import (
|
28 |
+
flash_attn_func,
|
29 |
+
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
30 |
+
)
|
31 |
+
|
32 |
+
# rename `max_seqlen`
|
33 |
+
def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, **kwargs):
|
34 |
+
return flash_attn_func(qkv, cu_seqlens, dropout_p=dropout_p, max_s=max_seqlen, **kwargs)
|
35 |
+
|
36 |
+
else:
|
37 |
+
from flash_attn.flash_attn_interface import (
|
38 |
+
flash_attn_varlen_kvpacked_func,
|
39 |
+
)
|
40 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
41 |
+
is_flash_attention_available = True
|
42 |
+
except ImportError:
|
43 |
+
pass
|
44 |
+
|
45 |
+
logger = logging.get_logger(__name__)
|
46 |
+
|
47 |
+
LegacyCache = Tuple[Tuple[torch.FloatTensor]]
|
48 |
+
|
49 |
+
# Taken from https://github.com/allenai/allennlp/blob/main/allennlp/nn/util.py
|
50 |
+
def info_value_of_dtype(dtype: torch.dtype):
|
51 |
+
"""
|
52 |
+
Returns the `finfo` or `iinfo` object of a given PyTorch data type. Does not allow torch.bool.
|
53 |
+
"""
|
54 |
+
if dtype == torch.bool:
|
55 |
+
raise TypeError("Does not support torch.bool")
|
56 |
+
elif dtype.is_floating_point:
|
57 |
+
return torch.finfo(dtype)
|
58 |
+
else:
|
59 |
+
return torch.iinfo(dtype)
|
60 |
+
|
61 |
+
|
62 |
+
# Taken from https://github.com/allenai/allennlp/blob/main/allennlp/nn/util.py
|
63 |
+
def min_value_of_dtype(dtype: torch.dtype):
|
64 |
+
"""
|
65 |
+
Returns the minimum value of a given PyTorch data type. Does not allow torch.bool.
|
66 |
+
"""
|
67 |
+
return info_value_of_dtype(dtype).min
|
68 |
+
|
69 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
70 |
+
def _get_unpad_data(attention_mask):
|
71 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
72 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
73 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
74 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
75 |
+
return (
|
76 |
+
indices,
|
77 |
+
cu_seqlens,
|
78 |
+
max_seqlen_in_batch,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
@torch.jit.script
|
83 |
+
def quick_gelu(x):
|
84 |
+
return x * torch.sigmoid(1.702 * x)
|
85 |
+
|
86 |
+
|
87 |
+
@torch.jit.script
|
88 |
+
def gegelu(input, limit: Optional[float] = None):
|
89 |
+
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
|
90 |
+
if limit is not None:
|
91 |
+
a_gelu = torch.where(
|
92 |
+
torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
|
93 |
+
)
|
94 |
+
a_linear = torch.where(
|
95 |
+
torch.isinf(a_linear), a_linear, a_linear.clamp(min=-limit, max=limit)
|
96 |
+
)
|
97 |
+
out_gelu = quick_gelu(a_gelu)
|
98 |
+
return out_gelu * (a_linear + 1)
|
99 |
+
|
100 |
+
def collapse_first_n_dims(x: torch.Tensor, n: int) -> torch.Tensor:
|
101 |
+
"""
|
102 |
+
Collapse the first `n` dimensions of a tensor into a single dimension.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
x (torch.Tensor): The input tensor.
|
106 |
+
n (int): The number of dimensions to collapse.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
torch.Tensor: The output tensor.
|
110 |
+
"""
|
111 |
+
return x.view(-1, *x.shape[n:])
|
112 |
+
|
113 |
+
def pad_tensor_to_next_mult_of(
|
114 |
+
tensor: torch.Tensor,
|
115 |
+
dim: int,
|
116 |
+
n: int,
|
117 |
+
) -> Tuple[torch.Tensor, int]:
|
118 |
+
"""
|
119 |
+
Pads a tensor along a specified dimension to the next multiple of a given number.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
tensor (torch.Tensor): The input tensor.
|
123 |
+
dim (int): The dimension along which to pad the tensor.
|
124 |
+
n (int): The number to pad the tensor to the next multiple of.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
Tuple[torch.Tensor, int]: A tuple containing the padded tensor and the amount of padding added.
|
128 |
+
"""
|
129 |
+
residual = tensor.size(dim) % n
|
130 |
+
if residual == 0:
|
131 |
+
return tensor, 0
|
132 |
+
padding = n - residual
|
133 |
+
padding_tensor = torch.zeros((*tensor.size()[:dim], padding, *tensor.size()[dim + 1:]), device=tensor.device, dtype=tensor.dtype)
|
134 |
+
return torch.cat([tensor, padding_tensor], dim=dim), padding
|
135 |
+
|
136 |
+
def strip_padding_from_tensor(
|
137 |
+
tensor: torch.Tensor,
|
138 |
+
dim: int,
|
139 |
+
residual: int,
|
140 |
+
) -> torch.Tensor:
|
141 |
+
"""
|
142 |
+
Removes padding from a tensor along a specified dimension.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
tensor (torch.Tensor): The input tensor.
|
146 |
+
dim (int): The dimension along which to remove padding.
|
147 |
+
residual (int): The amount of padding to remove.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
torch.Tensor: The tensor with padding removed along the specified dimension.
|
151 |
+
"""
|
152 |
+
return torch.narrow(tensor, dim, 0, tensor.size(dim) - residual)
|
153 |
+
|
154 |
+
class Phi3SmallMLP(nn.Module):
|
155 |
+
def __init__(self, config: Phi3SmallConfig):
|
156 |
+
super().__init__()
|
157 |
+
self.config = config
|
158 |
+
assert self.config.hidden_act == "gegelu", "Only `gegelu` is supported for the Phi-3-small model .."
|
159 |
+
self.hidden_size = config.hidden_size
|
160 |
+
self.gegelu_limit = config.gegelu_limit
|
161 |
+
self.intermediate_size = config.intermediate_size
|
162 |
+
|
163 |
+
self.up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size)
|
164 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
|
165 |
+
self.dropout = nn.Dropout(config.ffn_dropout_prob)
|
166 |
+
|
167 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
168 |
+
return self.dropout(
|
169 |
+
self.down_proj(
|
170 |
+
gegelu(self.up_proj(x), limit=self.gegelu_limit)
|
171 |
+
)
|
172 |
+
)
|
173 |
+
|
174 |
+
|
175 |
+
class Phi3SmallSelfAttention(nn.Module):
|
176 |
+
def __init__(self, config: Phi3SmallConfig, layer_idx: Optional[int] = None) -> None:
|
177 |
+
super().__init__()
|
178 |
+
self.config = config
|
179 |
+
self.layer_idx = layer_idx
|
180 |
+
if layer_idx is None:
|
181 |
+
logger.warning_once(
|
182 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
183 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
184 |
+
"when creating this class."
|
185 |
+
)
|
186 |
+
|
187 |
+
self.hidden_size = config.hidden_size
|
188 |
+
# Number of Query Heads
|
189 |
+
self.num_heads = config.num_attention_heads
|
190 |
+
self.head_dim = self.hidden_size // self.num_heads
|
191 |
+
# Number of Key Value Heads
|
192 |
+
self.num_key_value_heads = config.num_key_value_heads
|
193 |
+
self.num_q_per_kv = self.num_heads // self.num_key_value_heads
|
194 |
+
self.max_position_embeddings = config.max_position_embeddings
|
195 |
+
self.rope_embedding_base = config.rope_embedding_base
|
196 |
+
self.rope_position_scale = config.rope_position_scale
|
197 |
+
self.is_causal = True
|
198 |
+
|
199 |
+
self.attention_dropout_rate = config.attention_dropout_prob
|
200 |
+
|
201 |
+
norm_factor = None
|
202 |
+
if config.mup_use_scaling:
|
203 |
+
norm_factor = self.head_dim / config.mup_attn_multiplier
|
204 |
+
else:
|
205 |
+
norm_factor = math.sqrt(self.head_dim)
|
206 |
+
self.softmax_scale = 1.0 / norm_factor
|
207 |
+
|
208 |
+
self.query_key_value = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim)
|
209 |
+
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
|
210 |
+
|
211 |
+
self.blocksparse_params = None
|
212 |
+
# layer_idx is 0 indexed because that's what the KV Cache expects.
|
213 |
+
if self.config.dense_attention_every_n_layers and ((self.layer_idx + 1) % self.config.dense_attention_every_n_layers == 0):
|
214 |
+
logger.info(
|
215 |
+
f"Layer {layer_idx + 1} is using dense attention since it is divisible by "
|
216 |
+
f"{self.config.dense_attention_every_n_layers}"
|
217 |
+
)
|
218 |
+
assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
|
219 |
+
else:
|
220 |
+
# BlockSparse related Parameters
|
221 |
+
self.blocksparse_params = BlockSparseParams.from_config(config)
|
222 |
+
|
223 |
+
if self.blocksparse:
|
224 |
+
active_head_range = None
|
225 |
+
"""
|
226 |
+
... note(bapatra)::
|
227 |
+
|
228 |
+
In case of tensor parallelism and while using the heterogeneous head patterns,
|
229 |
+
the active head range needs to be modified based on the tensor parallel rank
|
230 |
+
and the tensor parallel world size.
|
231 |
+
|
232 |
+
This is because in the case of heterogeneous head patterns, the kernel needs to know
|
233 |
+
which head is on which device, so that it can pick the corresponding blocksparse head
|
234 |
+
pattern correctly.
|
235 |
+
|
236 |
+
Example:
|
237 |
+
```python
|
238 |
+
|
239 |
+
if not self.blocksparse_params.homo_head_pattern:
|
240 |
+
tp_rank = torch.distributed.get_rank() % tp_world_size
|
241 |
+
num_heads_per_partition = num_heads // tp_world_size
|
242 |
+
active_head_range = (tp_rank * num_heads_per_partition, (tp_rank + 1) * num_heads_per_partition)
|
243 |
+
|
244 |
+
```
|
245 |
+
|
246 |
+
"""
|
247 |
+
|
248 |
+
self._blocksparse_layer = BlockSparseAttentionLayer(
|
249 |
+
n_heads=self.num_heads,
|
250 |
+
max_seq_len=self.max_position_embeddings,
|
251 |
+
sparse_block_size=self.blocksparse_params.block_size,
|
252 |
+
local_blocks=self.blocksparse_params.num_local_blocks,
|
253 |
+
vert_stride=self.blocksparse_params.vert_stride,
|
254 |
+
kernel_block_size=self.blocksparse_params.kernel_block_size,
|
255 |
+
homo_head=self.blocksparse_params.homo_head_pattern,
|
256 |
+
active_head_range=active_head_range,
|
257 |
+
)
|
258 |
+
self.rotary_emb = RotaryEmbedding.from_config(config)
|
259 |
+
|
260 |
+
|
261 |
+
@property
|
262 |
+
def blocksparse(self):
|
263 |
+
return self.blocksparse_params is not None
|
264 |
+
|
265 |
+
def _split_heads(self, mixed_x_layer: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
266 |
+
bs, sq, _ = mixed_x_layer.size()
|
267 |
+
r"""
|
268 |
+
The main idea is that we group tensors as
|
269 |
+
[bs, sq, (q00, q01, ... q0m, k0, v0), (q10, q11, ... q1m, k1, v1), ... (qn0, qn1, ... qnm, kn, vn)]
|
270 |
+
That ways, when the MP column sharding happens, this tensor will be sharded keeping all the
|
271 |
+
queries and keys intact. In order to get the correct qkv, we first break into groups, and then
|
272 |
+
index into the groups.
|
273 |
+
"""
|
274 |
+
|
275 |
+
intermediate_shape = (bs, sq, -1, (self.num_q_per_kv + 2), self.head_dim)
|
276 |
+
mixed_x_layer = mixed_x_layer.view(*intermediate_shape)
|
277 |
+
q = mixed_x_layer[:, :, :, :-2]
|
278 |
+
k = mixed_x_layer[:, :, :, [-2]]
|
279 |
+
v = mixed_x_layer[:, :, :, [-1]]
|
280 |
+
q, k, v = [
|
281 |
+
rearrange(
|
282 |
+
x,
|
283 |
+
"bs sq group nh hn -> bs sq (group nh) hn"
|
284 |
+
) for x in (q, k, v)
|
285 |
+
]
|
286 |
+
return q, k, v
|
287 |
+
|
288 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._unpad_input
|
289 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
290 |
+
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
291 |
+
|
292 |
+
|
293 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
294 |
+
|
295 |
+
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
296 |
+
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
297 |
+
|
298 |
+
if query_length == kv_seq_len:
|
299 |
+
query_layer = index_first_axis(
|
300 |
+
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
301 |
+
)
|
302 |
+
cu_seqlens_q = cu_seqlens_k
|
303 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
304 |
+
indices_q = indices_k
|
305 |
+
elif query_length == 1:
|
306 |
+
max_seqlen_in_batch_q = 1
|
307 |
+
cu_seqlens_q = torch.arange(
|
308 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
309 |
+
) # There is a memcpy here, that is very bad.
|
310 |
+
indices_q = cu_seqlens_q[:-1]
|
311 |
+
query_layer = query_layer.squeeze(1)
|
312 |
+
else:
|
313 |
+
# The -q_len: slice assumes left padding.
|
314 |
+
attention_mask = attention_mask[:, -query_length:]
|
315 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
316 |
+
|
317 |
+
return (
|
318 |
+
query_layer,
|
319 |
+
key_layer,
|
320 |
+
value_layer,
|
321 |
+
indices_q,
|
322 |
+
(cu_seqlens_q, cu_seqlens_k),
|
323 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
324 |
+
)
|
325 |
+
|
326 |
+
def _apply_blocksparse_attention(
|
327 |
+
self,
|
328 |
+
q: torch.Tensor,
|
329 |
+
k: torch.Tensor,
|
330 |
+
v: torch.Tensor,
|
331 |
+
attention_mask: Optional[torch.LongTensor],
|
332 |
+
return_attention_probs: bool = False,
|
333 |
+
) -> torch.Tensor:
|
334 |
+
"""
|
335 |
+
Applies blocksparse attention to the input tensors.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
q (torch.Tensor): The query tensor of shape (bs, nqp, seq_len, hn).
|
339 |
+
k (torch.Tensor): The key tensor of shape (bs, nkp, seq_len, hn).
|
340 |
+
v (torch.Tensor): The value tensor of shape (bs, nkp, seq_len, hn).
|
341 |
+
attention_mask (Optional[torch.LongTensor]): The attention mask tensor of shape (bs, seq_len).
|
342 |
+
return_attention_probs (bool, optional): Whether to return attention probabilities. Defaults to False.
|
343 |
+
|
344 |
+
Returns:
|
345 |
+
torch.Tensor: The context layer tensor of shape (bs, nqp, seq_len, hn).
|
346 |
+
"""
|
347 |
+
assert not return_attention_probs, "return_attention_probs is not supported for blocksparse attention"
|
348 |
+
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
349 |
+
# shape: (bs, nqp, seq_len, hn)
|
350 |
+
if torch.is_grad_enabled():
|
351 |
+
# Training or non-batched inference
|
352 |
+
context_layer = self._blocksparse_layer(
|
353 |
+
q=q, k=k, v=v, sm_scale=self.softmax_scale
|
354 |
+
)
|
355 |
+
elif attention_mask is None:
|
356 |
+
if q.size(0) != 1:
|
357 |
+
logger.warning_once(
|
358 |
+
"You are attempting to do batched inference without passing the attention mask.\n"
|
359 |
+
"This is okay if you are running loglikelihood requests. However, if you want to do generation, "
|
360 |
+
"this probably won't work as expected. Please pass the attention mask to the forward function."
|
361 |
+
)
|
362 |
+
context_layer = self._blocksparse_layer(
|
363 |
+
q=q, k=k, v=v, sm_scale=self.softmax_scale
|
364 |
+
)
|
365 |
+
else:
|
366 |
+
"""
|
367 |
+
Shapes of tensors are as follows:
|
368 |
+
q: (bs, nqp, seq_len, hdim)
|
369 |
+
k: (bs, nkp, seq_len, hdim)
|
370 |
+
v: (bs, nkp, seq_len, hdim)
|
371 |
+
We first need to transpose the shapes to fit what the
|
372 |
+
kernel needs, and the reinvert it back at the end of the operations
|
373 |
+
"""
|
374 |
+
assert attention_mask.ndim == 2, "The kernel, like flash-attention-2, only supports 2d attention masks ..."
|
375 |
+
left_paddings = attention_mask.shape[1] - attention_mask.sum(dim=-1)
|
376 |
+
# shape: (bs, seq_len, nqp, hdim)
|
377 |
+
q = q.transpose(1, 2).contiguous()
|
378 |
+
# shape: (bs, seq_len, nkp, hdim)
|
379 |
+
k = k.transpose(1, 2).contiguous()
|
380 |
+
# shape: (bs, seq_len, nkp, hdim)
|
381 |
+
v = v.transpose(1, 2).contiguous()
|
382 |
+
context_layer = self._blocksparse_layer(
|
383 |
+
q=q, k=k, v=v, sm_scale=self.softmax_scale, left_paddings=left_paddings.to(torch.int32)
|
384 |
+
)
|
385 |
+
# shape: (bs, nqp, seq_len, hdim)
|
386 |
+
context_layer = context_layer.transpose(1, 2).contiguous()
|
387 |
+
return context_layer
|
388 |
+
|
389 |
+
def _apply_dense_attention(
|
390 |
+
self,
|
391 |
+
q: torch.Tensor,
|
392 |
+
k: torch.Tensor,
|
393 |
+
v: torch.Tensor,
|
394 |
+
attention_mask: torch.Tensor,
|
395 |
+
return_attention_probs: bool = False,
|
396 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
397 |
+
"""
|
398 |
+
Apply dense attention
|
399 |
+
|
400 |
+
Args:
|
401 |
+
q (torch.Tensor):
|
402 |
+
The query tensor, shape: (bs, num_query_heads, seq_len, head_size)
|
403 |
+
k (torch.Tensor):
|
404 |
+
The key tensor, shape: (bs, num_query_heads, seq_len, head_size)
|
405 |
+
v (torch.Tensor):
|
406 |
+
The value tensor, shape: (bs, num_query_heads, seq_len, head_size)
|
407 |
+
|
408 |
+
return_attention_probs (bool, optional):
|
409 |
+
Return the attention probabilities. Defaults to False.
|
410 |
+
|
411 |
+
Returns:
|
412 |
+
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
413 |
+
Return the output of the attention aggregation. If `return_attention_probs` is True, then
|
414 |
+
also return the attention probabilities
|
415 |
+
|
416 |
+
.. note::
|
417 |
+
Right now, am assuming the expansion for the query key values is already done
|
418 |
+
outside. But ideally, since Flash attention handles the GQA correctly, we can
|
419 |
+
avoid doing that.
|
420 |
+
|
421 |
+
"""
|
422 |
+
attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
|
423 |
+
# Get into the correct shape for the Flash Attention API
|
424 |
+
# shape: (bs, seq_len, nqp, hn)
|
425 |
+
q = q.transpose(1, 2).contiguous()
|
426 |
+
query_length = q.size(1)
|
427 |
+
# shape: (bs, seq_len, npq, hn)
|
428 |
+
k = k.transpose(1, 2).contiguous()
|
429 |
+
# shape: (bs, seq_len, npq, hn)
|
430 |
+
v = v.transpose(1, 2).contiguous()
|
431 |
+
|
432 |
+
if attention_mask is not None:
|
433 |
+
causal = q.size(2) == k.size(2)
|
434 |
+
batch_size = q.shape[0]
|
435 |
+
flat_q, flat_k, flat_v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
436 |
+
q, k, v, attention_mask, query_length
|
437 |
+
)
|
438 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
439 |
+
max_seqlen_q, max_seqlen_k = max_seq_lens
|
440 |
+
flat_kv = torch.cat((flat_k.unsqueeze(1), flat_v.unsqueeze(1)), dim=1)
|
441 |
+
attn_output_unpad = flash_attn_varlen_kvpacked_func(
|
442 |
+
q=flat_q,
|
443 |
+
kv=flat_kv,
|
444 |
+
cu_seqlens_q=cu_seqlens_q,
|
445 |
+
cu_seqlens_k=cu_seqlens_k,
|
446 |
+
max_seqlen_q=max_seqlen_q,
|
447 |
+
max_seqlen_k=max_seqlen_k,
|
448 |
+
dropout_p=attention_dropout_prob,
|
449 |
+
softmax_scale=self.softmax_scale,
|
450 |
+
causal=causal,
|
451 |
+
return_attn_probs=return_attention_probs
|
452 |
+
)
|
453 |
+
attention_output = pad_input(
|
454 |
+
attn_output_unpad, indices_q, batch_size, query_length
|
455 |
+
)
|
456 |
+
else:
|
457 |
+
kv = torch.cat((k.unsqueeze(2), v.unsqueeze(2)), dim=2)
|
458 |
+
cu_seqlens_q = torch.arange(
|
459 |
+
0, (q.size(0) + 1), device=q.device, dtype=torch.int32
|
460 |
+
) * q.size(1)
|
461 |
+
cu_seqlens_kv = torch.arange(
|
462 |
+
0, (kv.size(0) + 1), device=kv.device, dtype=torch.int32
|
463 |
+
) * kv.size(1)
|
464 |
+
max_seqlen_q = q.size(1)
|
465 |
+
max_seqlen_k = kv.size(1)
|
466 |
+
attention_output = flash_attn_varlen_kvpacked_func(
|
467 |
+
q=collapse_first_n_dims(q, 2),
|
468 |
+
kv=collapse_first_n_dims(kv, 2),
|
469 |
+
cu_seqlens_q=cu_seqlens_q,
|
470 |
+
cu_seqlens_k=cu_seqlens_kv,
|
471 |
+
max_seqlen_q=max_seqlen_q,
|
472 |
+
max_seqlen_k=max_seqlen_k,
|
473 |
+
dropout_p=attention_dropout_prob,
|
474 |
+
softmax_scale=self.softmax_scale,
|
475 |
+
causal=q.size(1) == kv.size(1),
|
476 |
+
return_attn_probs=return_attention_probs
|
477 |
+
)
|
478 |
+
if return_attention_probs:
|
479 |
+
(context_layer, attn_probs) = attention_output
|
480 |
+
context_layer = context_layer.view(q.size(0), q.size(1), -1, q.size(3)).transpose(1, 2).contiguous()
|
481 |
+
return (context_layer, attn_probs)
|
482 |
+
context_layer = attention_output
|
483 |
+
context_layer = context_layer.view(q.size(0), q.size(1), -1, q.size(3)).transpose(1, 2).contiguous()
|
484 |
+
return context_layer
|
485 |
+
|
486 |
+
|
487 |
+
def expand_kv_to_q_size(self, kv: torch.Tensor, num_q_per_kv: int) -> torch.Tensor:
|
488 |
+
"""
|
489 |
+
Expand the key-value tensor to match the size of the query tensor.
|
490 |
+
|
491 |
+
Args:
|
492 |
+
kv (torch.Tensor): The key-value tensor of shape (bsz, nkp, 2, seq_len, hdim).
|
493 |
+
num_q_per_kv (int): The number of queries per key-value.
|
494 |
+
|
495 |
+
Returns:
|
496 |
+
torch.Tensor: The expanded key-value tensor of shape (bsz, nqp, 2, seq_len, hdim).
|
497 |
+
Where nqp = num_q_per_kv * nkp
|
498 |
+
|
499 |
+
.. note(bapatra)::
|
500 |
+
Right now, I am using a repeat_interleave to expand the kv to the size of q.
|
501 |
+
This incurs a memory penalty, since the tensors are actually copied.
|
502 |
+
TODO: If this does yield benefits, then potentially we can use the re-written
|
503 |
+
flash attention kernel that can handle GQA.
|
504 |
+
"""
|
505 |
+
|
506 |
+
repeats = torch.tensor([num_q_per_kv] * kv.size(1)).to(kv.device)
|
507 |
+
total = repeats.sum()
|
508 |
+
expanded_kv = torch.repeat_interleave(
|
509 |
+
kv,
|
510 |
+
repeats=repeats,
|
511 |
+
dim=1,
|
512 |
+
output_size=total
|
513 |
+
)
|
514 |
+
return expanded_kv
|
515 |
+
|
516 |
+
def forward(
|
517 |
+
self,
|
518 |
+
hidden_states: torch.Tensor,
|
519 |
+
attention_mask: Optional[torch.Tensor] = None,
|
520 |
+
position_ids: Optional[torch.LongTensor] = None,
|
521 |
+
past_key_values: Optional[Cache] = None,
|
522 |
+
output_attentions: bool = False,
|
523 |
+
use_cache: bool = False,
|
524 |
+
**kwargs,
|
525 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
526 |
+
"""
|
527 |
+
The forward function of the Self Attention Layer.
|
528 |
+
|
529 |
+
Args:
|
530 |
+
hidden_states (torch.Tensor):
|
531 |
+
The input tensor of shape (bs, q_len, h).
|
532 |
+
attention_mask (Optional[torch.Tensor], optional):
|
533 |
+
The attention mask tensor of shape (bs, seq_len). This is the 2D attention mask tensor as is standard in the flash-attention
|
534 |
+
kernel.
|
535 |
+
Defaults to None.
|
536 |
+
position_ids (Optional[torch.LongTensor], optional):
|
537 |
+
The position ids tensor of shape (bs, q_len). Defaults to None. Unused by the function.
|
538 |
+
past_key_value (Optional[Cache], optional):
|
539 |
+
The previous kv cache values. Defaults to None.
|
540 |
+
output_attentions (bool, optional):
|
541 |
+
Whether to return the attention scores. Defaults to False.
|
542 |
+
.. note::
|
543 |
+
For the blocksparse attention kernel, we do not support returning the attention scores.
|
544 |
+
use_cache (bool, optional):
|
545 |
+
Whether to use the cache for storing the kv. Defaults to False.
|
546 |
+
|
547 |
+
Returns:
|
548 |
+
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
549 |
+
The output tensor of shape (bs, q_len, h),
|
550 |
+
the attention scores tensor of shape (bs, nqp, q_len, seq_len) if `output_attentions` is True,
|
551 |
+
and the updated cache values if `use_cache` is True.
|
552 |
+
|
553 |
+
Notations:
|
554 |
+
------------
|
555 |
+
bs: batch size
|
556 |
+
sq_len: sequence length of the entire sequence
|
557 |
+
q_len: sequence length of the query
|
558 |
+
cache_sq: sequence length in the cache
|
559 |
+
If there is no cache then cache_sq = 0
|
560 |
+
and sq_len = q_len
|
561 |
+
otherwise sq_len = q_len + cache_sq
|
562 |
+
h: hidden size
|
563 |
+
nq: number of query heads
|
564 |
+
nkv: number of key heads
|
565 |
+
hn: hidden size per head
|
566 |
+
hn = h // nq
|
567 |
+
nqp: number of query heads (per MP partition)
|
568 |
+
nqp = nq // (num mp partitions)
|
569 |
+
nkvp: number of key-value heads (per MP partition)
|
570 |
+
nkvp = nk // (num mp partitions)
|
571 |
+
|
572 |
+
"""
|
573 |
+
# shape: (bs, q_len, h)
|
574 |
+
bsz, q_len, _ = hidden_states.size()
|
575 |
+
|
576 |
+
# shape: (bs, q_len, (nqp + 2 * nkvp) * hn)
|
577 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
578 |
+
# shape: (bs, q_len, nqp, hn), shape: (bs, q_len, nkvp, hn), shape: (bs, q_len, nkvp, hn)
|
579 |
+
q, k, v = self._split_heads(mixed_x_layer)
|
580 |
+
|
581 |
+
# shape: (bs, qnp, q_len, hn)
|
582 |
+
query_states = q.permute(0, 2, 1, 3).contiguous()
|
583 |
+
# shape: (bs, nkvp, q_len, hn)
|
584 |
+
key_states = k.permute(0, 2, 1, 3).contiguous()
|
585 |
+
# shape: (bs, nkvp, q_len, hn)
|
586 |
+
value_states = v.permute(0, 2, 1, 3).contiguous()
|
587 |
+
|
588 |
+
kv_seq_len = key_states.shape[-2]
|
589 |
+
if past_key_values is not None:
|
590 |
+
if self.layer_idx is None:
|
591 |
+
raise ValueError(
|
592 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
593 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
594 |
+
"with a layer index."
|
595 |
+
)
|
596 |
+
if self.rotary_emb is not None:
|
597 |
+
seqlen_offset = past_key_values.get_usable_length(kv_seq_len, layer_idx=self.layer_idx)
|
598 |
+
# shape: (bs, nqp, q_len, hn), shape: (bs, nkvp, q_len, hn)
|
599 |
+
query_states, key_states = self.rotary_emb(
|
600 |
+
query_states, key_states, seq_dimension=2, seqlen_offset=seqlen_offset
|
601 |
+
)
|
602 |
+
key_states, value_states = past_key_values.update(key_states=key_states, value_states=value_states, layer_idx=self.layer_idx)
|
603 |
+
else:
|
604 |
+
# In this case seq_len = q_len and cache_sq = 0
|
605 |
+
if self.rotary_emb is not None:
|
606 |
+
# shape: (bs, nqp, seq_len, hn), shape: (bs, nkvp, seq_len, hn)
|
607 |
+
query_states, key_states = self.rotary_emb(query_states, key_states, seq_dimension=2)
|
608 |
+
|
609 |
+
# shape: (bs, nkvp, 2, seq_len, hn)
|
610 |
+
kv_states = torch.cat((key_states.unsqueeze(2), value_states.unsqueeze(2)), dim=2)
|
611 |
+
# shape: (bs, nqp, 2, seq_len, hn)
|
612 |
+
expanded_kv_states = self.expand_kv_to_q_size(kv_states, num_q_per_kv=self.num_q_per_kv)
|
613 |
+
# shape: (bs, nqp, seq_len, hn), shape: (bs, nqp, seq_len, hn)
|
614 |
+
expanded_key_states, expanded_value_states = expanded_kv_states[:, :, 0], expanded_kv_states[:, :, 1]
|
615 |
+
if self.blocksparse:
|
616 |
+
attn_function_output = self._apply_blocksparse_attention(
|
617 |
+
q=query_states,
|
618 |
+
k=expanded_key_states,
|
619 |
+
v=expanded_value_states,
|
620 |
+
attention_mask=attention_mask,
|
621 |
+
return_attention_probs=output_attentions
|
622 |
+
)
|
623 |
+
else:
|
624 |
+
attn_function_output = self._apply_dense_attention(
|
625 |
+
q=query_states,
|
626 |
+
k=expanded_key_states,
|
627 |
+
v=expanded_value_states,
|
628 |
+
attention_mask=attention_mask,
|
629 |
+
return_attention_probs=output_attentions
|
630 |
+
)
|
631 |
+
|
632 |
+
attn_weights = None
|
633 |
+
if output_attentions:
|
634 |
+
attn_output, attn_weights = attn_function_output
|
635 |
+
else:
|
636 |
+
# shape: (bs, nqp, seq_len, hn)
|
637 |
+
attn_output = attn_function_output
|
638 |
+
# shape: (bs, seq_len, nqp, hn)
|
639 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
640 |
+
|
641 |
+
# shape: (bs, seq_len, h)
|
642 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
643 |
+
attn_output = self.dense(attn_output)
|
644 |
+
return attn_output, attn_weights, past_key_values
|
645 |
+
|
646 |
+
|
647 |
+
class Phi3SmallDecoderLayer(nn.Module):
|
648 |
+
def __init__(self, config: Phi3SmallConfig, layer_idx: int):
|
649 |
+
super().__init__()
|
650 |
+
self.hidden_size = config.hidden_size
|
651 |
+
self.self_attn = Phi3SmallSelfAttention(config, layer_idx)
|
652 |
+
self.mlp = Phi3SmallMLP(config)
|
653 |
+
|
654 |
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
655 |
+
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
656 |
+
|
657 |
+
def forward(
|
658 |
+
self,
|
659 |
+
hidden_states: torch.Tensor,
|
660 |
+
attention_mask: Optional[torch.Tensor] = None,
|
661 |
+
position_ids: Optional[torch.LongTensor] = None,
|
662 |
+
past_key_values: Optional[Cache] = None,
|
663 |
+
output_attentions: Optional[bool] = None,
|
664 |
+
use_cache: Optional[bool] = None,
|
665 |
+
**kwargs,
|
666 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Cache]]:
|
667 |
+
residual = hidden_states
|
668 |
+
hidden_states = self.input_layernorm(hidden_states)
|
669 |
+
|
670 |
+
# Self Attention
|
671 |
+
hidden_states, self_attn_weights, present_key_values = self.self_attn(
|
672 |
+
hidden_states=hidden_states,
|
673 |
+
attention_mask=attention_mask,
|
674 |
+
position_ids=position_ids,
|
675 |
+
past_key_values=past_key_values,
|
676 |
+
output_attentions=output_attentions,
|
677 |
+
use_cache=use_cache,
|
678 |
+
)
|
679 |
+
hidden_states = residual + hidden_states
|
680 |
+
|
681 |
+
# Fully Connected
|
682 |
+
residual = hidden_states
|
683 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
684 |
+
hidden_states = self.mlp(hidden_states)
|
685 |
+
hidden_states = residual + hidden_states
|
686 |
+
|
687 |
+
outputs = (hidden_states,)
|
688 |
+
|
689 |
+
if output_attentions:
|
690 |
+
outputs += (self_attn_weights,)
|
691 |
+
|
692 |
+
if use_cache:
|
693 |
+
outputs += (present_key_values,)
|
694 |
+
|
695 |
+
return outputs
|
696 |
+
|
697 |
+
|
698 |
+
|
699 |
+
class Phi3SmallPreTrainedModel(PreTrainedModel):
|
700 |
+
config_class = Phi3SmallConfig
|
701 |
+
base_model_prefix = "model"
|
702 |
+
supports_gradient_checkpointing = True
|
703 |
+
_no_split_modules = ["Phi3SmallDecoderLayer"]
|
704 |
+
skip_keys_device_placement = "past_key_values"
|
705 |
+
_supports_flash_attn_2 = True
|
706 |
+
_supports_sdpa = False
|
707 |
+
_supports_cache_class = True
|
708 |
+
|
709 |
+
def _init_weights(self, module: nn.Module):
|
710 |
+
std = self.config.initializer_range
|
711 |
+
if isinstance(module, nn.Linear):
|
712 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
713 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
714 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
715 |
+
elif isinstance(module, nn.Embedding):
|
716 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
717 |
+
if module.padding_idx is not None:
|
718 |
+
module.weight.data[module.padding_idx].zero_()
|
719 |
+
elif isinstance(module, nn.LayerNorm):
|
720 |
+
module.bias.data.zero_()
|
721 |
+
module.weight.data.fill_(1.0)
|
722 |
+
|
723 |
+
# The output projection on the decoder attention layer as well as the down_proj in the MLP are scaled
|
724 |
+
# differently (dubbed `output_layer_init_method` in the Megatron code). This is replicated here
|
725 |
+
for name, p in module.named_parameters():
|
726 |
+
if any(x in name for x in ("c_proj.weight", "down_proj.weight", "o_proj.weight")):
|
727 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
728 |
+
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers)))
|
729 |
+
|
730 |
+
|
731 |
+
class Phi3SmallModel(Phi3SmallPreTrainedModel):
|
732 |
+
|
733 |
+
def __init__(self, config):
|
734 |
+
super().__init__(config)
|
735 |
+
self.config = config
|
736 |
+
|
737 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
738 |
+
|
739 |
+
# Embedding Dropout
|
740 |
+
self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob)
|
741 |
+
|
742 |
+
# MuP Embedding scaling
|
743 |
+
self.mup_embedding_multiplier = config.mup_embedding_multiplier
|
744 |
+
|
745 |
+
self.layers = nn.ModuleList([Phi3SmallDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
746 |
+
|
747 |
+
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
748 |
+
|
749 |
+
self.gradient_checkpointing = False
|
750 |
+
|
751 |
+
# Initialize weights and apply final processing
|
752 |
+
self.post_init()
|
753 |
+
|
754 |
+
def get_input_embeddings(self):
|
755 |
+
return self.embed_tokens
|
756 |
+
|
757 |
+
def set_input_embeddings(self, value):
|
758 |
+
self.embed_tokens = value
|
759 |
+
|
760 |
+
@property
|
761 |
+
def pad_sequence_to_multiple_of_64(self):
|
762 |
+
# We only need to do this for the backward pass. So only required
|
763 |
+
# when we are in the context of generating gradients
|
764 |
+
return self.config.pad_sequence_to_multiple_of_64 and torch.is_grad_enabled()
|
765 |
+
|
766 |
+
def forward(
|
767 |
+
self,
|
768 |
+
input_ids: torch.LongTensor = None,
|
769 |
+
attention_mask: Optional[torch.Tensor] = None,
|
770 |
+
position_ids: Optional[torch.LongTensor] = None,
|
771 |
+
past_key_values: Optional[Union[Cache, LegacyCache]] = None,
|
772 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
773 |
+
use_cache: Optional[bool] = None,
|
774 |
+
output_attentions: Optional[bool] = None,
|
775 |
+
output_hidden_states: Optional[bool] = None,
|
776 |
+
return_dict: Optional[bool] = None,
|
777 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
778 |
+
|
779 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
780 |
+
output_hidden_states = (
|
781 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
782 |
+
)
|
783 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
784 |
+
|
785 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
786 |
+
|
787 |
+
if input_ids is not None and inputs_embeds is not None:
|
788 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
789 |
+
elif input_ids is not None:
|
790 |
+
batch_size, seq_length = input_ids.shape
|
791 |
+
elif inputs_embeds is not None:
|
792 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
793 |
+
else:
|
794 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
795 |
+
|
796 |
+
if self.gradient_checkpointing and self.training:
|
797 |
+
if use_cache:
|
798 |
+
logger.warning_once(
|
799 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
800 |
+
)
|
801 |
+
use_cache = False
|
802 |
+
|
803 |
+
past_key_values_length = 0
|
804 |
+
|
805 |
+
if use_cache:
|
806 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
807 |
+
if use_legacy_cache:
|
808 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
809 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
810 |
+
|
811 |
+
if position_ids is None:
|
812 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
813 |
+
position_ids = torch.arange(
|
814 |
+
past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device
|
815 |
+
)
|
816 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
817 |
+
else:
|
818 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
819 |
+
|
820 |
+
if attention_mask is not None:
|
821 |
+
if batch_size <= 0:
|
822 |
+
raise ValueError("batch_size has to be defined and > 0")
|
823 |
+
|
824 |
+
if inputs_embeds is None:
|
825 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
826 |
+
inputs_embeds = self.embedding_dropout(inputs_embeds)
|
827 |
+
|
828 |
+
if self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0:
|
829 |
+
inputs_embeds = inputs_embeds * self.mup_embedding_multiplier
|
830 |
+
|
831 |
+
residual = 0
|
832 |
+
if self.pad_sequence_to_multiple_of_64:
|
833 |
+
# note(bapatra): Since we don't particularly use the position_ids and the attention mask
|
834 |
+
# we don't need to pad them
|
835 |
+
inputs_embeds, residual = pad_tensor_to_next_mult_of(tensor=inputs_embeds, dim=1, n=64)
|
836 |
+
|
837 |
+
hidden_states = inputs_embeds
|
838 |
+
|
839 |
+
# decoder layers
|
840 |
+
all_hidden_states = () if output_hidden_states else None
|
841 |
+
all_self_attns = () if output_attentions else None
|
842 |
+
next_decoder_cache = None
|
843 |
+
|
844 |
+
for decoder_layer in self.layers:
|
845 |
+
if output_hidden_states:
|
846 |
+
all_hidden_states += (hidden_states,)
|
847 |
+
|
848 |
+
if self.gradient_checkpointing and self.training:
|
849 |
+
layer_outputs = self._gradient_checkpointing_func(
|
850 |
+
decoder_layer.__call__,
|
851 |
+
hidden_states,
|
852 |
+
attention_mask,
|
853 |
+
position_ids,
|
854 |
+
past_key_values,
|
855 |
+
output_attentions,
|
856 |
+
use_cache,
|
857 |
+
)
|
858 |
+
else:
|
859 |
+
layer_outputs = decoder_layer(
|
860 |
+
hidden_states,
|
861 |
+
attention_mask=attention_mask,
|
862 |
+
position_ids=position_ids,
|
863 |
+
past_key_values=past_key_values,
|
864 |
+
output_attentions=output_attentions,
|
865 |
+
use_cache=use_cache,
|
866 |
+
)
|
867 |
+
hidden_states = layer_outputs[0]
|
868 |
+
|
869 |
+
if use_cache:
|
870 |
+
# Following the Mistral schema for layer return values
|
871 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
872 |
+
if output_attentions:
|
873 |
+
all_self_attns += (layer_outputs[1],)
|
874 |
+
|
875 |
+
hidden_states = self.final_layernorm(hidden_states)
|
876 |
+
|
877 |
+
if residual > 0:
|
878 |
+
hidden_states = strip_padding_from_tensor(tensor=hidden_states, dim=1, residual=residual)
|
879 |
+
|
880 |
+
# add hidden states from the last decoder layer
|
881 |
+
if output_hidden_states:
|
882 |
+
all_hidden_states += (hidden_states,)
|
883 |
+
|
884 |
+
next_cache = None
|
885 |
+
if use_cache:
|
886 |
+
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
887 |
+
|
888 |
+
if not return_dict:
|
889 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
890 |
+
return BaseModelOutputWithPast(
|
891 |
+
last_hidden_state=hidden_states,
|
892 |
+
past_key_values=next_cache,
|
893 |
+
hidden_states=all_hidden_states,
|
894 |
+
attentions=all_self_attns,
|
895 |
+
)
|
896 |
+
|
897 |
+
|
898 |
+
class Phi3SmallForCausalLM(Phi3SmallPreTrainedModel):
|
899 |
+
_tied_weights_keys = ["lm_head.weight"]
|
900 |
+
|
901 |
+
def __init__(self, config):
|
902 |
+
super().__init__(config)
|
903 |
+
self.model = Phi3SmallModel(config)
|
904 |
+
self.vocab_size = config.vocab_size
|
905 |
+
self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
|
906 |
+
self.mup_width_multiplier = config.mup_width_multiplier
|
907 |
+
|
908 |
+
# Create the mask for the dummy tokens in the vocabulary
|
909 |
+
dummy_token_indices = config.dummy_token_indices
|
910 |
+
dummy_tokens_mask = torch.zeros(self.vocab_size).bool()
|
911 |
+
dummy_tokens_mask[dummy_token_indices] = True
|
912 |
+
# shape: (vocab_size,)
|
913 |
+
self.register_buffer("dummy_tokens_mask", dummy_tokens_mask, persistent=False)
|
914 |
+
|
915 |
+
# Initialize weights and apply final processing
|
916 |
+
self.post_init()
|
917 |
+
|
918 |
+
def get_input_embeddings(self):
|
919 |
+
return self.model.embed_tokens
|
920 |
+
|
921 |
+
def set_input_embeddings(self, value):
|
922 |
+
self.model.embed_tokens = value
|
923 |
+
|
924 |
+
def get_output_embeddings(self):
|
925 |
+
return self.lm_head
|
926 |
+
|
927 |
+
def set_output_embeddings(self, value):
|
928 |
+
self.lm_head = value
|
929 |
+
|
930 |
+
def set_decoder(self, decoder):
|
931 |
+
self.model = decoder
|
932 |
+
|
933 |
+
def get_decoder(self):
|
934 |
+
return self.model
|
935 |
+
|
936 |
+
def forward(
|
937 |
+
self,
|
938 |
+
input_ids: torch.LongTensor = None,
|
939 |
+
attention_mask: Optional[torch.Tensor] = None,
|
940 |
+
position_ids: Optional[torch.LongTensor] = None,
|
941 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
942 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
943 |
+
labels: Optional[torch.LongTensor] = None,
|
944 |
+
use_cache: Optional[bool] = None,
|
945 |
+
output_attentions: Optional[bool] = None,
|
946 |
+
output_hidden_states: Optional[bool] = None,
|
947 |
+
return_dict: Optional[bool] = None,
|
948 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
949 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
950 |
+
output_hidden_states = (
|
951 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
952 |
+
)
|
953 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
954 |
+
|
955 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
956 |
+
outputs = self.model(
|
957 |
+
input_ids=input_ids,
|
958 |
+
attention_mask=attention_mask,
|
959 |
+
position_ids=position_ids,
|
960 |
+
past_key_values=past_key_values,
|
961 |
+
inputs_embeds=inputs_embeds,
|
962 |
+
use_cache=use_cache,
|
963 |
+
output_attentions=output_attentions,
|
964 |
+
output_hidden_states=output_hidden_states,
|
965 |
+
return_dict=return_dict,
|
966 |
+
)
|
967 |
+
|
968 |
+
hidden_states = outputs[0]
|
969 |
+
logits = self.lm_head(hidden_states)
|
970 |
+
logits = logits.float()
|
971 |
+
if self.mup_width_multiplier:
|
972 |
+
logits = logits / self.mup_width_multiplier
|
973 |
+
logits = logits.masked_fill(self.dummy_tokens_mask, min_value_of_dtype(logits.dtype))
|
974 |
+
|
975 |
+
loss = None
|
976 |
+
if labels is not None:
|
977 |
+
# Shift so that tokens < n predict n
|
978 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
979 |
+
shift_labels = labels[..., 1:].contiguous()
|
980 |
+
# Flatten the tokens
|
981 |
+
loss_fct = nn.CrossEntropyLoss()
|
982 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
983 |
+
shift_labels = shift_labels.view(-1)
|
984 |
+
# Enable model parallelism
|
985 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
986 |
+
loss = loss_fct(shift_logits, shift_labels)
|
987 |
+
|
988 |
+
if not return_dict:
|
989 |
+
output = (logits,) + outputs[1:]
|
990 |
+
return (loss,) + output if loss is not None else output
|
991 |
+
|
992 |
+
return CausalLMOutputWithPast(
|
993 |
+
loss=loss,
|
994 |
+
logits=logits,
|
995 |
+
past_key_values=outputs.past_key_values,
|
996 |
+
hidden_states=outputs.hidden_states,
|
997 |
+
attentions=outputs.attentions,
|
998 |
+
)
|
999 |
+
|
1000 |
+
def prepare_inputs_for_generation(
|
1001 |
+
self,
|
1002 |
+
input_ids: torch.LongTensor,
|
1003 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1004 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1005 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1006 |
+
**kwargs
|
1007 |
+
) -> Dict[str, Any]:
|
1008 |
+
# only last token for inputs_ids if past is defined in kwargs
|
1009 |
+
if past_key_values:
|
1010 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
1011 |
+
|
1012 |
+
position_ids = kwargs.get("position_ids", None)
|
1013 |
+
|
1014 |
+
if attention_mask is not None and position_ids is None:
|
1015 |
+
# create position_ids on the fly for batch generation
|
1016 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1017 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1018 |
+
if past_key_values:
|
1019 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
1020 |
+
else:
|
1021 |
+
position_ids = None
|
1022 |
+
|
1023 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1024 |
+
if inputs_embeds is not None and past_key_values is None:
|
1025 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1026 |
+
else:
|
1027 |
+
model_inputs = {"input_ids": input_ids}
|
1028 |
+
|
1029 |
+
model_inputs.update(
|
1030 |
+
{
|
1031 |
+
"past_key_values": past_key_values,
|
1032 |
+
"use_cache": kwargs.get("use_cache"),
|
1033 |
+
"position_ids": position_ids,
|
1034 |
+
"attention_mask": attention_mask,
|
1035 |
+
}
|
1036 |
+
)
|
1037 |
+
return model_inputs
|
1038 |
+
|
1039 |
+
|
1040 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with Mistral -> Phi3Small
|
1041 |
+
class Phi3SmallForSequenceClassification(Phi3SmallPreTrainedModel):
|
1042 |
+
def __init__(self, config):
|
1043 |
+
super().__init__(config)
|
1044 |
+
self.num_labels = config.num_labels
|
1045 |
+
self.model = Phi3SmallModel(config)
|
1046 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1047 |
+
|
1048 |
+
# Initialize weights and apply final processing
|
1049 |
+
self.post_init()
|
1050 |
+
|
1051 |
+
def get_input_embeddings(self):
|
1052 |
+
return self.model.embed_tokens
|
1053 |
+
|
1054 |
+
def set_input_embeddings(self, value):
|
1055 |
+
self.model.embed_tokens = value
|
1056 |
+
|
1057 |
+
|
1058 |
+
def forward(
|
1059 |
+
self,
|
1060 |
+
input_ids: torch.LongTensor = None,
|
1061 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1062 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1063 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1064 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1065 |
+
labels: Optional[torch.LongTensor] = None,
|
1066 |
+
use_cache: Optional[bool] = None,
|
1067 |
+
output_attentions: Optional[bool] = None,
|
1068 |
+
output_hidden_states: Optional[bool] = None,
|
1069 |
+
return_dict: Optional[bool] = None,
|
1070 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1071 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1072 |
+
|
1073 |
+
transformer_outputs = self.model(
|
1074 |
+
input_ids,
|
1075 |
+
attention_mask=attention_mask,
|
1076 |
+
position_ids=position_ids,
|
1077 |
+
past_key_values=past_key_values,
|
1078 |
+
inputs_embeds=inputs_embeds,
|
1079 |
+
use_cache=use_cache,
|
1080 |
+
output_attentions=output_attentions,
|
1081 |
+
output_hidden_states=output_hidden_states,
|
1082 |
+
return_dict=return_dict,
|
1083 |
+
)
|
1084 |
+
hidden_states = transformer_outputs[0]
|
1085 |
+
logits = self.score(hidden_states)
|
1086 |
+
|
1087 |
+
if input_ids is not None:
|
1088 |
+
batch_size = input_ids.shape[0]
|
1089 |
+
else:
|
1090 |
+
batch_size = inputs_embeds.shape[0]
|
1091 |
+
|
1092 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1093 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1094 |
+
if self.config.pad_token_id is None:
|
1095 |
+
sequence_lengths = -1
|
1096 |
+
else:
|
1097 |
+
if input_ids is not None:
|
1098 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1099 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1100 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1101 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1102 |
+
else:
|
1103 |
+
sequence_lengths = -1
|
1104 |
+
|
1105 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1106 |
+
|
1107 |
+
loss = None
|
1108 |
+
if labels is not None:
|
1109 |
+
labels = labels.to(logits.device)
|
1110 |
+
if self.config.problem_type is None:
|
1111 |
+
if self.num_labels == 1:
|
1112 |
+
self.config.problem_type = "regression"
|
1113 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1114 |
+
self.config.problem_type = "single_label_classification"
|
1115 |
+
else:
|
1116 |
+
self.config.problem_type = "multi_label_classification"
|
1117 |
+
|
1118 |
+
if self.config.problem_type == "regression":
|
1119 |
+
loss_fct = nn.MSELoss()
|
1120 |
+
if self.num_labels == 1:
|
1121 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1122 |
+
else:
|
1123 |
+
loss = loss_fct(pooled_logits, labels)
|
1124 |
+
elif self.config.problem_type == "single_label_classification":
|
1125 |
+
loss_fct = nn.CrossEntropyLoss()
|
1126 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1127 |
+
elif self.config.problem_type == "multi_label_classification":
|
1128 |
+
loss_fct = nn.BCEWithLogitsLoss()
|
1129 |
+
loss = loss_fct(pooled_logits, labels)
|
1130 |
+
if not return_dict:
|
1131 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1132 |
+
return ((loss,) + output) if loss is not None else output
|
1133 |
+
|
1134 |
+
return SequenceClassifierOutputWithPast(
|
1135 |
+
loss=loss,
|
1136 |
+
logits=pooled_logits,
|
1137 |
+
past_key_values=transformer_outputs.past_key_values,
|
1138 |
+
hidden_states=transformer_outputs.hidden_states,
|
1139 |
+
attentions=transformer_outputs.attentions,
|
1140 |
+
)
|
positional_embedding.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Orginally Taken verbatim from xformers library
|
3 |
+
https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py
|
4 |
+
|
5 |
+
The difference is that xformers seems to assume the inputs to be
|
6 |
+
(bs, head, seq_len, dim) while we assume (bs, seq_len, head, dim)
|
7 |
+
|
8 |
+
"""
|
9 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
10 |
+
#
|
11 |
+
# This source code is licensed under the BSD license found in the
|
12 |
+
# LICENSE file in the root directory of this source tree.
|
13 |
+
|
14 |
+
|
15 |
+
# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
|
16 |
+
# NOTE: Almost the same right now, moving parts to Triton is the next step
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import List, Optional, Tuple, Dict, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import dataclasses
|
23 |
+
from transformers.utils import logging
|
24 |
+
|
25 |
+
from transformers import PretrainedConfig
|
26 |
+
|
27 |
+
is_dacite_available = False
|
28 |
+
try:
|
29 |
+
import dacite
|
30 |
+
is_dacite_available = True
|
31 |
+
except ImportError:
|
32 |
+
pass
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__)
|
35 |
+
|
36 |
+
@dataclasses.dataclass
|
37 |
+
class LongRopeConfig(object):
|
38 |
+
short_factor: List[float]
|
39 |
+
long_factor: List[float]
|
40 |
+
original_max_position_embeddings: int
|
41 |
+
type: str = "longrope"
|
42 |
+
short_mscale: float = -1
|
43 |
+
long_mscale: float = -1
|
44 |
+
|
45 |
+
|
46 |
+
def __post_init__(self):
|
47 |
+
assert self.type in ("longrope", "su"), f"Invalid type {self.type} for LongRopeConfig. Expected longrope / su"
|
48 |
+
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def from_dict(cls, config_dict: Dict[str, Union[float, List[float], int]]) -> "LongRopeConfig":
|
52 |
+
if is_dacite_available:
|
53 |
+
# Preferred since we can also type check the input
|
54 |
+
return dacite.from_dict(data_class=cls, data=config_dict)
|
55 |
+
kwargs = {}
|
56 |
+
for field in dataclasses.fields(cls):
|
57 |
+
if field.name in config_dict:
|
58 |
+
if field.init:
|
59 |
+
kwargs[field.name] = config_dict[field.name]
|
60 |
+
else:
|
61 |
+
raise ValueError(f"Field {field.name} is not initiable")
|
62 |
+
else:
|
63 |
+
if field.default is dataclasses.MISSING:
|
64 |
+
raise ValueError(f"Field {field.name} is required")
|
65 |
+
extra_keys = set(config_dict.keys()) - set(kwargs.keys())
|
66 |
+
if len(extra_keys) > 0:
|
67 |
+
for key in extra_keys:
|
68 |
+
logger.error(f"Unrecognized key {key} in config_dict")
|
69 |
+
raise ValueError(f"Unrecognized keys in config_dict")
|
70 |
+
return cls(**kwargs)
|
71 |
+
|
72 |
+
def rotate_half(x):
|
73 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
74 |
+
return torch.cat((-x2, x1), dim=x1.ndim - 1)
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
@torch.jit.script
|
79 |
+
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int):
|
80 |
+
# NOTE: This could probably be moved to Triton
|
81 |
+
|
82 |
+
if seq_dimension == 0:
|
83 |
+
cos = cos[: x.shape[0], None, None, :]
|
84 |
+
sin = sin[: x.shape[0], None, None, :]
|
85 |
+
elif seq_dimension == 1:
|
86 |
+
# Handle a possible sequence length mismatch in between q and k
|
87 |
+
cos = cos[None, : x.shape[1], None, :]
|
88 |
+
sin = sin[None, : x.shape[1], None, :]
|
89 |
+
elif seq_dimension == 2:
|
90 |
+
cos = cos[None, None, : x.shape[2], :]
|
91 |
+
sin = sin[None, None, : x.shape[2], :]
|
92 |
+
|
93 |
+
return (x * cos) + (rotate_half(x) * sin)
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
class RotaryEmbedding(torch.nn.Module):
|
98 |
+
"""
|
99 |
+
Adapted from the xformers library
|
100 |
+
|
101 |
+
The rotary position embeddings from RoFormer_ (Su et. al).
|
102 |
+
A crucial insight from the method is that the query and keys are
|
103 |
+
transformed by rotation matrices which depend on the relative positions.
|
104 |
+
Other implementations are available in the Rotary Transformer repo_ and in
|
105 |
+
GPT-NeoX_, GPT-NeoX was an inspiration
|
106 |
+
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
107 |
+
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
108 |
+
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
109 |
+
.. warning: Please note that this embedding is not registered on purpose, as it is transformative
|
110 |
+
(it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
|
111 |
+
|
112 |
+
# Arguments
|
113 |
+
:param dim_mode: head dimention
|
114 |
+
:param max_seq_len:
|
115 |
+
:param default_seq_dimension: which dim is the sequence length
|
116 |
+
:param dtype: cos/sin dtype
|
117 |
+
:param use_fused_kernel: if to use customized fused kernel.
|
118 |
+
Note: if used, q, k will be modified inplace. Ok for both forward & backward.
|
119 |
+
"""
|
120 |
+
|
121 |
+
def __init__(
|
122 |
+
self,
|
123 |
+
dim_model: int,
|
124 |
+
*,
|
125 |
+
max_seq_len: Optional[int] = None,
|
126 |
+
dtype: Optional[torch.dtype] = None,
|
127 |
+
base=10000,
|
128 |
+
position_scale=1,
|
129 |
+
device: Optional[torch.device] = None,
|
130 |
+
longrope_config: Optional[LongRopeConfig] = None,
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
self.base = base
|
134 |
+
self.dim_model = dim_model
|
135 |
+
self.max_seq_len = max_seq_len
|
136 |
+
self.longrope_config = longrope_config
|
137 |
+
|
138 |
+
if self.is_longrope:
|
139 |
+
# Keep the maximum range vector, and slice from it as needed
|
140 |
+
self.register_buffer(
|
141 |
+
"range_vector",
|
142 |
+
torch.arange(max_seq_len, device=device, dtype=torch.float32),
|
143 |
+
persistent=False
|
144 |
+
)
|
145 |
+
self.register_buffer(
|
146 |
+
"short_factors",
|
147 |
+
torch.tensor(self.longrope_config.short_factor, dtype=torch.float32),
|
148 |
+
persistent=False
|
149 |
+
)
|
150 |
+
self.register_buffer(
|
151 |
+
"long_factors",
|
152 |
+
torch.tensor(self.longrope_config.long_factor, dtype=torch.float32),
|
153 |
+
persistent=False
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
157 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float().to(device) / self.dim_model))
|
158 |
+
self.register_buffer("inv_freq", inv_freq)
|
159 |
+
|
160 |
+
self.position_scale = position_scale
|
161 |
+
|
162 |
+
if not self.is_longrope:
|
163 |
+
dtype = dtype or torch.get_default_dtype()
|
164 |
+
self._set_cos_sin_cache(
|
165 |
+
seq_len=max_seq_len,
|
166 |
+
device=self.inv_freq.device,
|
167 |
+
dtype=dtype,
|
168 |
+
)
|
169 |
+
@property
|
170 |
+
def is_longrope(self):
|
171 |
+
return self.longrope_config is not None
|
172 |
+
|
173 |
+
@property
|
174 |
+
def original_max_seq_len(self):
|
175 |
+
if self.longrope_config is not None:
|
176 |
+
return self.longrope_config.original_max_position_embeddings
|
177 |
+
logger.warning_once(
|
178 |
+
(
|
179 |
+
"``original_max_seq_len'' is being accessed, but longrope_config has not been set. "
|
180 |
+
"Please only do this if you are sure about the context."
|
181 |
+
)
|
182 |
+
)
|
183 |
+
return self.max_seq_len
|
184 |
+
|
185 |
+
def get_range_vector(self, seq_len: int, device: torch.device):
|
186 |
+
if self.is_longrope:
|
187 |
+
assert seq_len < self.range_vector.shape[0], f"Found seq_len {seq_len} greater than max_seq_len {self.range_vector.shape[0]}"
|
188 |
+
if self.range_vector.device != device:
|
189 |
+
self.range_vector = self.range_vector.to(device)
|
190 |
+
return self.range_vector[:seq_len]
|
191 |
+
return torch.arange(seq_len, device=device, dtype=torch.float32)
|
192 |
+
|
193 |
+
|
194 |
+
def _calc_mscale(self, scale: torch.Tensor) -> torch.Tensor:
|
195 |
+
if scale <= 1.0:
|
196 |
+
return 1.0
|
197 |
+
return math.sqrt(1 + math.log(scale) / math.log(self.original_max_seq_len))
|
198 |
+
|
199 |
+
def _set_cos_sin_cache(
|
200 |
+
self,
|
201 |
+
seq_len: int,
|
202 |
+
device: Optional[torch.device] = None,
|
203 |
+
dtype: Optional[torch.dtype] = None,
|
204 |
+
) -> None:
|
205 |
+
dtype = dtype or torch.get_default_dtype()
|
206 |
+
self.max_seq_len_cached = seq_len
|
207 |
+
t = (torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) * self.position_scale).type_as(self.inv_freq)
|
208 |
+
device_type = device.type if device is not None else "cpu"
|
209 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
210 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
211 |
+
# shape: (seq_len, dim_model // 2)
|
212 |
+
freqs = torch.outer(t, self.inv_freq)
|
213 |
+
# shape: (seq_len, dim_model)
|
214 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
215 |
+
cos = emb.cos()
|
216 |
+
sin = emb.sin()
|
217 |
+
self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
|
218 |
+
self.register_buffer("sin_cached", sin.to(dtype), persistent=False)
|
219 |
+
|
220 |
+
def forward(
|
221 |
+
self, q: torch.Tensor,
|
222 |
+
k: torch.Tensor,
|
223 |
+
seq_dimension: int = 1,
|
224 |
+
seqlen_offset: int = 0,
|
225 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
226 |
+
"""q, k does not include `seqlen_offset`
|
227 |
+
q: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
|
228 |
+
k: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
|
229 |
+
"""
|
230 |
+
if seq_dimension < 0:
|
231 |
+
seq_dimension = k.ndim + seq_dimension
|
232 |
+
assert seq_dimension in (0, 1, 2)
|
233 |
+
seq_len = k.shape[seq_dimension] + seqlen_offset
|
234 |
+
|
235 |
+
if self.is_longrope:
|
236 |
+
if seq_len > self.original_max_seq_len:
|
237 |
+
t = self.get_range_vector(seq_len, device=q.device)
|
238 |
+
rescale_factors = self.long_factors.to(q.device)
|
239 |
+
long_mscale = self.longrope_config.long_mscale
|
240 |
+
mscale = long_mscale if long_mscale > 0 else self._calc_mscale(self.max_seq_len / self.original_max_seq_len)
|
241 |
+
else:
|
242 |
+
t = self.get_range_vector(self.original_max_seq_len, device=q.device)
|
243 |
+
rescale_factors = self.short_factors.to(q.device)
|
244 |
+
short_mscale = self.longrope_config.short_mscale
|
245 |
+
mscale = short_mscale if short_mscale > 0 else 1.0
|
246 |
+
assert rescale_factors.shape == (self.dim_model // 2, ), (
|
247 |
+
f"misaligned shape for LongRoPE rescale factors:\n"
|
248 |
+
f"\tExpected {(self.dim_model // 2, )}, got {rescale_factors.shape}."
|
249 |
+
)
|
250 |
+
inv_freq = 1.0 / (rescale_factors * (self.base ** (torch.arange(0, self.dim_model, 2).float().to(q.device) / self.dim_model)))
|
251 |
+
device_type = q.device.type if q.device is not None else "cpu"
|
252 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
253 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
254 |
+
freqs = torch.outer(t, inv_freq)
|
255 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
256 |
+
cos = emb.cos() * mscale
|
257 |
+
sin = emb.sin() * mscale
|
258 |
+
cos_cached = cos.to(q.dtype)
|
259 |
+
sin_cached = sin.to(q.dtype)
|
260 |
+
else:
|
261 |
+
if seq_len > self.max_seq_len_cached:
|
262 |
+
self._set_cos_sin_cache(
|
263 |
+
seq_len=seq_len,
|
264 |
+
device=k.device,
|
265 |
+
dtype=k.dtype,
|
266 |
+
)
|
267 |
+
cos_cached = self.cos_cached
|
268 |
+
sin_cached = self.sin_cached
|
269 |
+
return (
|
270 |
+
apply_rotary_pos_emb(
|
271 |
+
q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
272 |
+
),
|
273 |
+
apply_rotary_pos_emb(
|
274 |
+
k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
275 |
+
),
|
276 |
+
)
|
277 |
+
|
278 |
+
@classmethod
|
279 |
+
def from_config(cls, config: PretrainedConfig) -> "RotaryEmbedding":
|
280 |
+
kwargs = dict(
|
281 |
+
dim_model=config.hidden_size // config.num_attention_heads,
|
282 |
+
max_seq_len=config.max_position_embeddings,
|
283 |
+
base=config.rope_embedding_base,
|
284 |
+
position_scale=config.rope_position_scale,
|
285 |
+
)
|
286 |
+
if config.rope_scaling is not None:
|
287 |
+
kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
|
288 |
+
return cls(**kwargs)
|
special_tokens_map.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<|endoftext|>",
|
3 |
+
"eos_token": "<|endoftext|>",
|
4 |
+
"pad_token": "<|endoftext|>"
|
5 |
+
}
|
tokenization_phi3_small.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/tokenization_qwen.py
|
2 |
+
import os
|
3 |
+
from typing import Collection, List, Optional, Dict, Set, Tuple, Union
|
4 |
+
|
5 |
+
from functools import cached_property
|
6 |
+
|
7 |
+
import base64
|
8 |
+
|
9 |
+
from transformers import PreTrainedTokenizer, AddedToken, AutoConfig
|
10 |
+
from transformers.models.auto.tokenization_auto import get_tokenizer_config
|
11 |
+
import tiktoken
|
12 |
+
|
13 |
+
|
14 |
+
"""
|
15 |
+
This tokenizer is almost identical to tiktoken.get_encoding("cl100k_base")
|
16 |
+
with a few additional special tokens to support the ChatML format.
|
17 |
+
|
18 |
+
TODO(bapatra): Right now, I do not save the special tokens to the vocab file.
|
19 |
+
Maybe in the future, that would be useful? Can add that support later.
|
20 |
+
|
21 |
+
"""
|
22 |
+
|
23 |
+
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
|
24 |
+
with open(tiktoken_bpe_file, "rb") as f:
|
25 |
+
contents = f.read()
|
26 |
+
return {
|
27 |
+
base64.b64decode(token): int(rank)
|
28 |
+
for token, rank in (line.split() for line in contents.splitlines() if line)
|
29 |
+
}
|
30 |
+
|
31 |
+
# On the megatron codebase, we pad vocabularies to ensure matrix multiplication is fast.
|
32 |
+
# this in turn causes some indices to be empty. We account for these empty indices by adding
|
33 |
+
# dummy tokens to the tokenizer.
|
34 |
+
|
35 |
+
EFFECTIVE_PADDED_VOCAB_SIZE = 100352
|
36 |
+
ACTUAL_VOCAB_SIZE = 100276
|
37 |
+
|
38 |
+
|
39 |
+
DUMMY_TOKENS = {
|
40 |
+
f"<|dummy_id_{11 + offset}|>": 100276 + offset
|
41 |
+
for offset in range(1, EFFECTIVE_PADDED_VOCAB_SIZE - ACTUAL_VOCAB_SIZE)
|
42 |
+
}
|
43 |
+
|
44 |
+
SPECIAL_TOKENS = {
|
45 |
+
# tiktoken.get_encoding("cl100k_base")._special_tokens
|
46 |
+
'<|endoftext|>': 100257,
|
47 |
+
'<|fim_prefix|>': 100258,
|
48 |
+
'<|fim_middle|>': 100259,
|
49 |
+
'<|fim_suffix|>': 100260,
|
50 |
+
# Special tokens for post-training
|
51 |
+
"<|system|>": 100261,
|
52 |
+
"<|user|>": 100262,
|
53 |
+
"<|assistant|>": 100263,
|
54 |
+
# Dummy unused tokens
|
55 |
+
"<|dummy_id_0|>": 100264,
|
56 |
+
"<|dummy_id_1|>": 100265,
|
57 |
+
# Special tokens for post-training continued
|
58 |
+
"<|end|>": 100266,
|
59 |
+
# Some dummy tokens, so that tokenization is contiguous and does not cause issues
|
60 |
+
# Note that the 100256th token of tiktoken.get_encoding("cl100k_base") does not
|
61 |
+
# actually map to anything. So we use a dummy token here.
|
62 |
+
"<|dummy_id_2|>": 100256,
|
63 |
+
# Likewise, tokens from 100267 to 100275 are also unused
|
64 |
+
"<|dummy_id_3|>": 100267,
|
65 |
+
"<|dummy_id_4|>": 100268,
|
66 |
+
"<|dummy_id_5|>": 100269,
|
67 |
+
"<|dummy_id_6|>": 100270,
|
68 |
+
"<|dummy_id_7|>": 100271,
|
69 |
+
"<|dummy_id_8|>": 100272,
|
70 |
+
"<|dummy_id_9|>": 100273,
|
71 |
+
"<|dummy_id_10|>": 100274,
|
72 |
+
"<|dummy_id_11|>": 100275,
|
73 |
+
# The final end of prompt token
|
74 |
+
# (unused, but present as a part of tiktoken.get_encoding("cl100k_base")._special_tokens)
|
75 |
+
'<|endofprompt|>': 100276,
|
76 |
+
# Dummy tokens to account for padding of the tokenizer
|
77 |
+
# We pad to ensure tensor cores are used for vocab multiplication
|
78 |
+
**DUMMY_TOKENS
|
79 |
+
}
|
80 |
+
|
81 |
+
class Phi3SmallTokenizer(PreTrainedTokenizer):
|
82 |
+
vocab_files_names = {
|
83 |
+
"vocab_file": "cl100k_base.tiktoken"
|
84 |
+
}
|
85 |
+
|
86 |
+
model_input_names: List[str] = ["input_ids", "attention_mask"]
|
87 |
+
padding_side = "left"
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
vocab_file: Optional[str] = None,
|
92 |
+
errors: str = "replace",
|
93 |
+
**kwargs
|
94 |
+
) -> None:
|
95 |
+
# PreTrainedTokenizer's init calls _add_tokens, which in turn checks
|
96 |
+
# if the token is present in `self.special_tokens``. Hence instantiating it here.
|
97 |
+
# The way Qwen gets around this is by checking against SPECIAL_TOKENS
|
98 |
+
# But I think it's better to check against the objects own `special_tokens`
|
99 |
+
# in case we eventually want to allow the tokenizer to have special tokens.
|
100 |
+
self.special_tokens = SPECIAL_TOKENS
|
101 |
+
|
102 |
+
super().__init__(**kwargs)
|
103 |
+
self.errors = errors
|
104 |
+
|
105 |
+
base = tiktoken.get_encoding("cl100k_base")
|
106 |
+
if vocab_file is None:
|
107 |
+
self.mergeable_ranks: Dict[bytes, int] = base._mergeable_ranks
|
108 |
+
else:
|
109 |
+
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
|
110 |
+
|
111 |
+
self.pat_str = base._pat_str
|
112 |
+
|
113 |
+
enc = tiktoken.Encoding(
|
114 |
+
name="phi3small",
|
115 |
+
pat_str=self.pat_str,
|
116 |
+
mergeable_ranks=self.mergeable_ranks,
|
117 |
+
special_tokens=self.special_tokens,
|
118 |
+
)
|
119 |
+
self.tokenizer = enc
|
120 |
+
|
121 |
+
self.decoder: Dict[int, bytes] = {
|
122 |
+
v: k for k, v in self.mergeable_ranks.items()
|
123 |
+
}
|
124 |
+
self.decoder.update({v: k for k, v in self.special_tokens.items()})
|
125 |
+
|
126 |
+
self.eod_id = self.tokenizer.eot_token
|
127 |
+
self._eos_token = self._convert_id_to_token(self.eod_id)
|
128 |
+
|
129 |
+
# Setting the bos_token to be the same as the eos_token
|
130 |
+
# Note that this is **not** the correct thing to do, and is done
|
131 |
+
# just so that some of the downstream libraries do not break.
|
132 |
+
self._bos_token = self._eos_token
|
133 |
+
|
134 |
+
# Assign the special tokens to class variables
|
135 |
+
self.system_id = self.special_tokens["<|system|>"]
|
136 |
+
self.user_id = self.special_tokens["<|user|>"]
|
137 |
+
self.assistant_id = self.special_tokens["<|assistant|>"]
|
138 |
+
self.end_id = self.special_tokens["<|end|>"]
|
139 |
+
|
140 |
+
@cached_property
|
141 |
+
def dummy_token_indices(self) -> List[int]:
|
142 |
+
# There are some additional special tokens in the cl100k_base tokenizer
|
143 |
+
# that we do not use. Hence, we also consider them to be dummy tokens.
|
144 |
+
additional_tokens = [
|
145 |
+
"<|fim_prefix|>",
|
146 |
+
"<|fim_middle|>",
|
147 |
+
"<|fim_suffix|>",
|
148 |
+
"<|endofprompt|>"
|
149 |
+
]
|
150 |
+
dummy_token_indices = [index for token, index in self.special_tokens.items() if "dummy_id" in token]
|
151 |
+
dummy_token_indices.extend([self.special_tokens[token] for token in additional_tokens])
|
152 |
+
return sorted(dummy_token_indices)
|
153 |
+
|
154 |
+
def __getstate__(self):
|
155 |
+
state = self.__dict__.copy()
|
156 |
+
del state["tokenizer"]
|
157 |
+
return state
|
158 |
+
|
159 |
+
def __setstate__(self, state):
|
160 |
+
self.__dict__ = state
|
161 |
+
enc = tiktoken.Encoding(
|
162 |
+
name="cl100k_im",
|
163 |
+
pat_str=self.pat_str,
|
164 |
+
mergeable_ranks=self.mergeable_ranks,
|
165 |
+
special_tokens=self.special_tokens,
|
166 |
+
)
|
167 |
+
self.tokenizer = enc
|
168 |
+
|
169 |
+
def __len__(self):
|
170 |
+
return self.tokenizer.n_vocab
|
171 |
+
|
172 |
+
@classmethod
|
173 |
+
def from_pretrained(
|
174 |
+
cls,
|
175 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
176 |
+
*init_inputs,
|
177 |
+
**kwargs,
|
178 |
+
):
|
179 |
+
cls_kwargs = kwargs
|
180 |
+
# First try to load from the tokenization config if it exists
|
181 |
+
tokenization_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
182 |
+
if tokenization_config:
|
183 |
+
cls_kwargs.update(
|
184 |
+
dict(
|
185 |
+
model_max_length=tokenization_config["model_max_length"],
|
186 |
+
chat_template=tokenization_config.get("chat_template", None)
|
187 |
+
)
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
191 |
+
cls_kwargs["model_max_length"] = config.max_position_embeddings
|
192 |
+
return cls(**cls_kwargs)
|
193 |
+
|
194 |
+
def get_vocab(self) -> Dict[Union[str, bytes], int]:
|
195 |
+
return {**self.mergeable_ranks, **self.special_tokens}
|
196 |
+
|
197 |
+
def convert_tokens_to_ids(
|
198 |
+
self,
|
199 |
+
tokens: Union[bytes, str, List[Union[bytes, str]]]
|
200 |
+
) -> Union[int, List[int]]:
|
201 |
+
ids = []
|
202 |
+
if isinstance(tokens, (str, bytes)):
|
203 |
+
if tokens in self.special_tokens:
|
204 |
+
return self.special_tokens[tokens]
|
205 |
+
else:
|
206 |
+
return self.mergeable_ranks.get(tokens)
|
207 |
+
ids: List[int] = []
|
208 |
+
for token in tokens:
|
209 |
+
ids.append(self.convert_tokens_to_ids(token))
|
210 |
+
return ids
|
211 |
+
|
212 |
+
def _add_tokens(
|
213 |
+
self,
|
214 |
+
new_tokens: Union[List[str], List[AddedToken]],
|
215 |
+
special_tokens: bool = False,
|
216 |
+
) -> int:
|
217 |
+
if not special_tokens and new_tokens:
|
218 |
+
raise ValueError("Only special tokens can be added to this tokenizer")
|
219 |
+
for token in new_tokens:
|
220 |
+
surface_form = token.content if isinstance(token, AddedToken) else token
|
221 |
+
if surface_form not in self.special_tokens:
|
222 |
+
raise ValueError(
|
223 |
+
"For now, we do not support unknown special tokens\n"
|
224 |
+
"In the future, if there is a need for this, we can add special tokens to the tokenizer\n"
|
225 |
+
"starting from rank 100261 - 100263 and then 100266 - 100275.\n"
|
226 |
+
"And finally, we can re-construct the enc object back\n"
|
227 |
+
)
|
228 |
+
return 0
|
229 |
+
|
230 |
+
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
|
231 |
+
file_path = os.path.join(save_directory, "cl100k_base.tiktoken")
|
232 |
+
with open(file_path, "w") as f:
|
233 |
+
for token, rank in self.mergeable_ranks.items():
|
234 |
+
line = base64.b64encode(token).decode("utf-8") + " " + str(rank) + "\n"
|
235 |
+
f.write(line)
|
236 |
+
return (file_path,)
|
237 |
+
|
238 |
+
def tokenize(
|
239 |
+
self,
|
240 |
+
text: str,
|
241 |
+
allowed_special: Union[Set, str] = "all",
|
242 |
+
disallowed_special: Union[Collection, str] = (),
|
243 |
+
**kwargs
|
244 |
+
) -> List[Union[bytes, str]]:
|
245 |
+
tokens: List[Union[bytes, str]] = []
|
246 |
+
for token_id in self.tokenizer.encode(
|
247 |
+
text, allowed_special=allowed_special, disallowed_special=disallowed_special
|
248 |
+
):
|
249 |
+
tokens.append(self.decoder[token_id])
|
250 |
+
return tokens
|
251 |
+
|
252 |
+
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
|
253 |
+
"""
|
254 |
+
Converts a sequence of tokens in a single string.
|
255 |
+
"""
|
256 |
+
text = ""
|
257 |
+
temp = b""
|
258 |
+
for t in tokens:
|
259 |
+
if isinstance(t, str):
|
260 |
+
if temp:
|
261 |
+
text += temp.decode("utf-8", errors=self.errors)
|
262 |
+
temp = b""
|
263 |
+
text += t
|
264 |
+
elif isinstance(t, bytes):
|
265 |
+
temp += t
|
266 |
+
else:
|
267 |
+
raise TypeError("token should only be of type types or str")
|
268 |
+
if temp:
|
269 |
+
text += temp.decode("utf-8", errors=self.errors)
|
270 |
+
return text
|
271 |
+
|
272 |
+
@property
|
273 |
+
def vocab_size(self):
|
274 |
+
return self.tokenizer.n_vocab
|
275 |
+
|
276 |
+
@property
|
277 |
+
def eos_token_id(self) -> int:
|
278 |
+
return self.eod_id
|
279 |
+
|
280 |
+
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
|
281 |
+
"""Converts an id to a token, special tokens included"""
|
282 |
+
if index in self.decoder:
|
283 |
+
return self.decoder[index]
|
284 |
+
raise ValueError("unknown ids")
|
285 |
+
|
286 |
+
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
|
287 |
+
"""Converts a token to an id using the vocab, special tokens included"""
|
288 |
+
if token in self.special_tokens:
|
289 |
+
return self.special_tokens[token]
|
290 |
+
if token in self.mergeable_ranks:
|
291 |
+
return self.mergeable_ranks[token]
|
292 |
+
raise ValueError("unknown token")
|
293 |
+
|
294 |
+
def _tokenize(self, text: str, **kwargs):
|
295 |
+
"""
|
296 |
+
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
297 |
+
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
|
298 |
+
Do NOT take care of added tokens.
|
299 |
+
"""
|
300 |
+
raise NotImplementedError
|
301 |
+
|
302 |
+
def _decode(
|
303 |
+
self,
|
304 |
+
token_ids: Union[int, List[int]],
|
305 |
+
skip_special_tokens: bool = False,
|
306 |
+
errors: str = None,
|
307 |
+
**kwargs,
|
308 |
+
) -> str:
|
309 |
+
if isinstance(token_ids, int):
|
310 |
+
token_ids = [token_ids]
|
311 |
+
if skip_special_tokens:
|
312 |
+
token_ids = [i for i in token_ids if i < self.eod_id]
|
313 |
+
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
|
314 |
+
|
315 |
+
|
tokenizer_config.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {},
|
3 |
+
"auto_map": {
|
4 |
+
"AutoTokenizer": [
|
5 |
+
"tokenization_phi3_small.Phi3SmallTokenizer",
|
6 |
+
"tokenization_phi3_small.Phi3SmallTokenizer"
|
7 |
+
]
|
8 |
+
},
|
9 |
+
"bos_token": "<|endoftext|>",
|
10 |
+
"chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
|
11 |
+
"clean_up_tokenization_spaces": true,
|
12 |
+
"eos_token": "<|endoftext|>",
|
13 |
+
"model_max_length": 8192,
|
14 |
+
"pad_token": "<|endoftext|>",
|
15 |
+
"tokenizer_class": "Phi3SmallTokenizer"
|
16 |
+
}
|
triton_blocksparse_attention_layer.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Tuple, TypeVar
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch
|
5 |
+
import triton
|
6 |
+
|
7 |
+
from functools import lru_cache
|
8 |
+
|
9 |
+
|
10 |
+
from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd
|
11 |
+
|
12 |
+
|
13 |
+
Layout = Tuple[torch.LongTensor, torch.LongTensor]
|
14 |
+
|
15 |
+
|
16 |
+
def create_sparse_attn_mask(
|
17 |
+
n_heads: int,
|
18 |
+
max_seq_len: int,
|
19 |
+
max_seq_len_k: int,
|
20 |
+
dtype: torch.dtype,
|
21 |
+
device: torch.device,
|
22 |
+
BLOCK: int,
|
23 |
+
local_blocks: int,
|
24 |
+
vert_stride: int,
|
25 |
+
homo_head: bool,
|
26 |
+
return_dense: bool
|
27 |
+
) -> Tuple[Layout, torch.Tensor, Optional[torch.Tensor]]:
|
28 |
+
layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
|
29 |
+
n_heads=n_heads,
|
30 |
+
q_len=max_seq_len,
|
31 |
+
N_CTX=max_seq_len_k,
|
32 |
+
dtype=dtype,
|
33 |
+
device=device,
|
34 |
+
BLOCK=BLOCK,
|
35 |
+
local_blocks=local_blocks,
|
36 |
+
vert_stride=vert_stride,
|
37 |
+
homo_head=homo_head,
|
38 |
+
return_dense=return_dense
|
39 |
+
)
|
40 |
+
return layout, block_sparse_pattern
|
41 |
+
|
42 |
+
|
43 |
+
class BlockSparseAttentionLayer(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
n_heads: int,
|
47 |
+
max_seq_len: int,
|
48 |
+
sparse_block_size: int,
|
49 |
+
local_blocks: int,
|
50 |
+
vert_stride: int,
|
51 |
+
kernel_block_size: Optional[int] = None,
|
52 |
+
homo_head: bool = False,
|
53 |
+
active_head_range: Optional[Tuple[int]] = None
|
54 |
+
) -> None:
|
55 |
+
super().__init__()
|
56 |
+
|
57 |
+
self.n_heads = n_heads
|
58 |
+
self.max_seq_len = max_seq_len
|
59 |
+
self.sparse_block_size = sparse_block_size
|
60 |
+
self.kernel_block_size = kernel_block_size or sparse_block_size
|
61 |
+
self.local_blocks = local_blocks
|
62 |
+
self.vert_stride = vert_stride
|
63 |
+
self.homo_head = homo_head
|
64 |
+
self.active_head_range = active_head_range
|
65 |
+
|
66 |
+
# Internal Parameters used by the layer
|
67 |
+
self._sparse_block_mask = None
|
68 |
+
self._sparse_layout = None
|
69 |
+
self._dtype = None
|
70 |
+
self._device = None
|
71 |
+
|
72 |
+
# TODO(bapatra): Ideally, I'd want to keep all the code for
|
73 |
+
# forward to be handled here, and not branch for training and inference.
|
74 |
+
# However, that refactor would need a lot of testing. For now, using the
|
75 |
+
# training op as is, and will refactor again later.
|
76 |
+
|
77 |
+
def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
|
78 |
+
self._sparse_block_mask = self._sparse_block_mask[h_start: h_end]
|
79 |
+
self._sparse_layout[0] = self._sparse_layout[0][h_start: h_end]
|
80 |
+
self._sparse_layout[1] = self._sparse_layout[1][h_start: h_end]
|
81 |
+
|
82 |
+
def _initialize_internals(
|
83 |
+
self,
|
84 |
+
dtype: torch.dtype,
|
85 |
+
device: torch.device
|
86 |
+
) -> None:
|
87 |
+
self._dtype, self._device = dtype, device
|
88 |
+
self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
|
89 |
+
n_heads=self.n_heads,
|
90 |
+
max_seq_len=self.max_seq_len,
|
91 |
+
max_seq_len_k=self.max_seq_len,
|
92 |
+
dtype=dtype,
|
93 |
+
device=device,
|
94 |
+
BLOCK=self.sparse_block_size,
|
95 |
+
local_blocks=self.local_blocks,
|
96 |
+
vert_stride=self.vert_stride,
|
97 |
+
homo_head=self.homo_head,
|
98 |
+
return_dense=False,
|
99 |
+
)
|
100 |
+
if (not self.homo_head) and (self.active_head_range is not None):
|
101 |
+
assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
|
102 |
+
h_start, h_end = self.active_head_range
|
103 |
+
self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)
|
104 |
+
|
105 |
+
assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
|
106 |
+
assert self.kernel_block_size >=16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"
|
107 |
+
if self.sparse_block_size // self.kernel_block_size > 1:
|
108 |
+
_mul = self.sparse_block_size // self.kernel_block_size
|
109 |
+
# need to consider if block_m and block_n are different
|
110 |
+
self._sparse_block_mask = torch.kron(self._sparse_block_mask, self._sparse_block_mask.new_ones(_mul, _mul))
|
111 |
+
num_sparse_blocks = self._sparse_block_mask.size(-1)
|
112 |
+
block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
|
113 |
+
self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)
|
114 |
+
|
115 |
+
|
116 |
+
def forward(
|
117 |
+
self,
|
118 |
+
q: torch.Tensor,
|
119 |
+
k: torch.Tensor,
|
120 |
+
v: torch.Tensor,
|
121 |
+
sm_scale: float,
|
122 |
+
*,
|
123 |
+
# Arguments Related to Block Attention Inference
|
124 |
+
left_paddings: Optional[torch.LongTensor] = None,
|
125 |
+
seqlens: Optional[torch.LongTensor] = None,
|
126 |
+
# Arguements Related to Variable Length Inference
|
127 |
+
cu_seqlens_k: Optional[torch.LongTensor] = None,
|
128 |
+
cu_seqlens_q: Optional[torch.LongTensor] = None,
|
129 |
+
) -> torch.Tensor:
|
130 |
+
|
131 |
+
if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
|
132 |
+
blocksparse_op = get_local_strided_sparse_attention_op(
|
133 |
+
n_heads=self.n_heads,
|
134 |
+
max_seq_len=self.max_seq_len,
|
135 |
+
sparse_block_size=self.sparse_block_size,
|
136 |
+
kernel_block_size=self.kernel_block_size,
|
137 |
+
local_blocks=self.local_blocks,
|
138 |
+
vert_stride=self.vert_stride,
|
139 |
+
homo_head=self.homo_head,
|
140 |
+
device=q.device,
|
141 |
+
inference=not self.training
|
142 |
+
)
|
143 |
+
return blocksparse_op(q, k, v, sm_scale)
|
144 |
+
|
145 |
+
assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"
|
146 |
+
# First set internals if they have not been set
|
147 |
+
if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
|
148 |
+
self._initialize_internals(dtype=q.dtype, device=q.device)
|
149 |
+
|
150 |
+
if k.dim() == 3:
|
151 |
+
assert cu_seqlens_k is not None
|
152 |
+
return blocksparse_flash_attn_varlen_fwd(
|
153 |
+
q=q,
|
154 |
+
k=k,
|
155 |
+
v=v,
|
156 |
+
cu_seqlens_k=cu_seqlens_k,
|
157 |
+
cu_seqlens_q=cu_seqlens_q,
|
158 |
+
sm_scale=sm_scale,
|
159 |
+
sparse_layout=self._sparse_layout,
|
160 |
+
block_size=self.kernel_block_size,
|
161 |
+
max_seqlen=self.max_seq_len,
|
162 |
+
)
|
163 |
+
if k.dim() == 4:
|
164 |
+
assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
|
165 |
+
return blocksparse_flash_attn_padded_fwd(
|
166 |
+
q=q,
|
167 |
+
k=k,
|
168 |
+
v=v,
|
169 |
+
sm_scale=sm_scale,
|
170 |
+
sparse_layout=self._sparse_layout,
|
171 |
+
left_paddings=left_paddings,
|
172 |
+
seqlens=seqlens,
|
173 |
+
block_size=self.kernel_block_size,
|
174 |
+
max_seqlen=self.max_seq_len,
|
175 |
+
)
|
176 |
+
raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')
|
triton_flash_blocksparse_attn.py
ADDED
@@ -0,0 +1,1943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Eric Lin (xihlin)
|
3 |
+
"""
|
4 |
+
"""
|
5 |
+
... note(bapatra)::
|
6 |
+
This is written as one big file, instead of splitting into logical components because I was running into issues with transformers auto module
|
7 |
+
imports when splitting into different files. I've tried keeping the logical partitions demarkated with comment blocks, but it is not ideal.
|
8 |
+
In the future, would be really good to revisit this and refactor into a more readable file structure.
|
9 |
+
|
10 |
+
"""
|
11 |
+
from typing import TypeVar
|
12 |
+
from functools import lru_cache
|
13 |
+
import math
|
14 |
+
import pytest
|
15 |
+
import torch
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
import triton
|
19 |
+
import triton.language as tl
|
20 |
+
|
21 |
+
import os
|
22 |
+
|
23 |
+
import dataclasses
|
24 |
+
|
25 |
+
Phi3SmallConfig = TypeVar('Phi3SmallConfig')
|
26 |
+
|
27 |
+
# triton 2.0.0: fail at backward on A100, for the examples, if h_dim=128.
|
28 |
+
|
29 |
+
# Done
|
30 |
+
# 1. strided of qkv
|
31 |
+
# 2. seq len not power of 2
|
32 |
+
# 3. bf16 with Triton May, 2023
|
33 |
+
|
34 |
+
# TODO:
|
35 |
+
# 1. wip: support non-contiguous backward, also help reduce memory allocation in training (q, k, v split)
|
36 |
+
# 2. block sparse with different BLOCK_M, BLOCK_N?
|
37 |
+
# 3. for Lq not divided by BLOCK_M, BLOCK_N, only apply mask to K/V on last batch, still need to apply mask on Q.
|
38 |
+
# Attempt, fail to compile
|
39 |
+
# 4. For 2nd iter of inference, BLOCK_M=1, how to make things work? K/V maynot divided by BLOCK_N.
|
40 |
+
# 5. The inner loop can also be paralled via bigger num_stage(better) or on different thread-block (via m/L and atomic update, but this no-comm/sync between blocks)
|
41 |
+
|
42 |
+
|
43 |
+
###########################################################
|
44 |
+
################### Kernel Parameters #####################
|
45 |
+
###########################################################
|
46 |
+
|
47 |
+
@dataclasses.dataclass
|
48 |
+
class BlockSparseParams(object):
|
49 |
+
block_size: int
|
50 |
+
kernel_block_size: int
|
51 |
+
num_local_blocks: int
|
52 |
+
vert_stride: int
|
53 |
+
homo_head_pattern: bool = False
|
54 |
+
|
55 |
+
@classmethod
|
56 |
+
def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
|
57 |
+
return cls(
|
58 |
+
block_size=config.blocksparse_block_size,
|
59 |
+
kernel_block_size=config.blocksparse_triton_kernel_block_size,
|
60 |
+
num_local_blocks=config.blocksparse_num_local_blocks,
|
61 |
+
vert_stride=config.blocksparse_vert_stride,
|
62 |
+
homo_head_pattern=config.blocksparse_homo_head_pattern,
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
###########################################################
|
67 |
+
###########################################################
|
68 |
+
|
69 |
+
###########################################################
|
70 |
+
################### Utility Functions #####################
|
71 |
+
###########################################################
|
72 |
+
|
73 |
+
# helper functions for 3D sparse pattern
|
74 |
+
# these function are not optimized and very inefficient. Avoid calling them too frequent.
|
75 |
+
# currently, it is only called within `get_local_strided_sparse_attention_op`, which is cached.
|
76 |
+
def dense_to_crow_col(x):
|
77 |
+
''' Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
|
78 |
+
param:
|
79 |
+
TODO:
|
80 |
+
1. improve efficiency, is it faster if done in CPU, or customize a cuda kernel for it?
|
81 |
+
NOTE: col_indices padded -1
|
82 |
+
'''
|
83 |
+
pad = -1
|
84 |
+
dim = x.dim()
|
85 |
+
assert x.dim() in (2, 3)
|
86 |
+
if x.dim() == 2:
|
87 |
+
x = x[None]
|
88 |
+
x = [xi.to_sparse_csr() for xi in x]
|
89 |
+
crows = torch.vstack([xi.crow_indices() for xi in x])
|
90 |
+
cols = [xi.col_indices() for xi in x]
|
91 |
+
max_cols = max(len(xi) for xi in cols)
|
92 |
+
cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols]
|
93 |
+
cols = torch.vstack(cols)
|
94 |
+
if dim == 2:
|
95 |
+
crows = crows[0]
|
96 |
+
cols = cols[0]
|
97 |
+
return crows, cols
|
98 |
+
|
99 |
+
|
100 |
+
def crow_col_to_dense(crows, cols, dtype=torch.float16):
|
101 |
+
dim = crows.dim()
|
102 |
+
if dim == 1:
|
103 |
+
crows = crows[None]
|
104 |
+
cols = cols[None]
|
105 |
+
device = crows.device
|
106 |
+
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
|
107 |
+
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
|
108 |
+
x = torch.zeros(shape, dtype=dtype)
|
109 |
+
for i in range(shape[0]):
|
110 |
+
for j in range(shape[1]):
|
111 |
+
x[i, j, cols[i, crows[i, j]:crows[i, j+1]]] = 1
|
112 |
+
if dim == 1:
|
113 |
+
x = x[0]
|
114 |
+
return x.to(device)
|
115 |
+
|
116 |
+
|
117 |
+
def dense_to_ccol_row(x):
|
118 |
+
'''Similar, but to CSC format
|
119 |
+
'''
|
120 |
+
x = x.transpose(-2, -1)
|
121 |
+
return dense_to_crow_col(x)
|
122 |
+
|
123 |
+
|
124 |
+
def ccol_row_to_dense(ccol, rows, dtype=torch.float16):
|
125 |
+
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
|
126 |
+
|
127 |
+
|
128 |
+
def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False):
|
129 |
+
'''
|
130 |
+
:return: a tuple of 3:
|
131 |
+
- tuple of crow_indices, col_indices representation of CSR format.
|
132 |
+
- block dense mask
|
133 |
+
- all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
|
134 |
+
'''
|
135 |
+
with torch.no_grad():
|
136 |
+
N_BLOCK = triton.cdiv(N_CTX, BLOCK)
|
137 |
+
q_pos = torch.arange(N_BLOCK)[:, None]
|
138 |
+
k_pos = torch.arange(N_BLOCK)[None]
|
139 |
+
mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0
|
140 |
+
block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
|
141 |
+
N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
|
142 |
+
block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr()
|
143 |
+
if return_dense:
|
144 |
+
mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
|
145 |
+
causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
|
146 |
+
mask_dense = mask_dense[-q_len:, :N_CTX] * causal_mask
|
147 |
+
return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense
|
148 |
+
else:
|
149 |
+
return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None
|
150 |
+
|
151 |
+
|
152 |
+
def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False):
|
153 |
+
'''
|
154 |
+
:return: a tuple of 3:
|
155 |
+
- tuple of crow_indices, col_indices representation of CSR format.
|
156 |
+
- block dense mask
|
157 |
+
- all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
|
158 |
+
'''
|
159 |
+
if homo_head:
|
160 |
+
with torch.no_grad():
|
161 |
+
(crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense)
|
162 |
+
crow = crow[None].expand(n_heads, crow.shape[0])
|
163 |
+
col = col[None].expand(n_heads, col.shape[0])
|
164 |
+
if return_dense:
|
165 |
+
mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape)
|
166 |
+
return (crow, col), block_mask_dense, mask_dense
|
167 |
+
|
168 |
+
with torch.no_grad():
|
169 |
+
N_BLOCK = triton.cdiv(N_CTX, BLOCK)
|
170 |
+
q_pos = torch.arange(N_BLOCK)[None, :, None]
|
171 |
+
k_pos = torch.arange(N_BLOCK)[None, None]
|
172 |
+
head_sliding_step = max(1, int(vert_stride / n_heads)) # if vert_stride <= n_heads, rotating the heads
|
173 |
+
mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)]
|
174 |
+
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
|
175 |
+
block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
|
176 |
+
N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
|
177 |
+
block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:]
|
178 |
+
if return_dense:
|
179 |
+
mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
|
180 |
+
causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
|
181 |
+
mask_dense = mask_dense[..., -q_len:, :N_CTX] * causal_mask[None]
|
182 |
+
return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense
|
183 |
+
else:
|
184 |
+
return dense_to_crow_col(block_mask_dense_output), block_mask_dense, None
|
185 |
+
|
186 |
+
|
187 |
+
def get_sparse_attn_mask(q, N_CTX, *args, **kwargs):
|
188 |
+
return _get_sparse_attn_mask(q.size(1), q.size(2), N_CTX, q.dtype, q.device, *args, **kwargs)
|
189 |
+
|
190 |
+
###########################################################
|
191 |
+
###########################################################
|
192 |
+
|
193 |
+
###########################################################
|
194 |
+
###################### Training Kernels ###################
|
195 |
+
###########################################################
|
196 |
+
|
197 |
+
# TODO: only apply loading/saving mask on the last iteration for EVEN_N_BLOCK, useful for 1st iteration of inference.
|
198 |
+
# Experiment failed inside loop.
|
199 |
+
# Another idea: only on saving? load even out of boundary(will it causes illegal access error)?
|
200 |
+
@triton.jit
|
201 |
+
def _fwd_kernel(
|
202 |
+
Q, K, V, sm_scale,
|
203 |
+
layout_crow_ptr,
|
204 |
+
layout_col_ptr,
|
205 |
+
layout_crow_stride_h, layout_crow_stride_m,
|
206 |
+
layout_col_stride_h, layout_col_stride_m,
|
207 |
+
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug. TMP, L, M are assumed to have contiguous layouts
|
208 |
+
Out,
|
209 |
+
stride_qz, stride_qh, stride_qm, stride_qd,
|
210 |
+
stride_kz, stride_kh, stride_kn, stride_kd,
|
211 |
+
stride_vz, stride_vh, stride_vn, stride_vd,
|
212 |
+
stride_oz, stride_oh, stride_om, stride_od,
|
213 |
+
Z, H, N_CTX,
|
214 |
+
PAST_LEN,
|
215 |
+
Q_ROUNDED_LEN,
|
216 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
217 |
+
BLOCK_N: tl.constexpr,
|
218 |
+
EVEN_M_BLOCK: tl.constexpr,
|
219 |
+
EVEN_N_BLOCK: tl.constexpr,
|
220 |
+
INFERENCE: tl.constexpr,
|
221 |
+
NUM_DBLOCKS: tl.constexpr,
|
222 |
+
):
|
223 |
+
Q_LEN = N_CTX - PAST_LEN
|
224 |
+
start_m = tl.program_id(0)
|
225 |
+
off_hz = tl.program_id(1)
|
226 |
+
off_h = off_hz % H
|
227 |
+
off_z = off_hz // H
|
228 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
229 |
+
K += off_z * stride_kz + off_h * stride_kh
|
230 |
+
V += off_z * stride_vz + off_h * stride_vh
|
231 |
+
# initialize offsets
|
232 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
233 |
+
offs_n = tl.arange(0, BLOCK_N)
|
234 |
+
offs_d = tl.arange(0, BLOCK_DMODEL)
|
235 |
+
off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
|
236 |
+
# off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
|
237 |
+
off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
|
238 |
+
off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
|
239 |
+
# Initialize pointers to Q, K, V
|
240 |
+
q_ptrs = Q + off_q
|
241 |
+
k_ptrs = K + off_k
|
242 |
+
v_ptrs = V + off_v
|
243 |
+
# initialize pointer to m and l
|
244 |
+
t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m
|
245 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
|
246 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
247 |
+
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
248 |
+
if NUM_DBLOCKS >= 2:
|
249 |
+
acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
250 |
+
|
251 |
+
# load q: it will stay in SRAM throughout
|
252 |
+
if EVEN_M_BLOCK:
|
253 |
+
q = tl.load(q_ptrs)
|
254 |
+
if NUM_DBLOCKS >= 2:
|
255 |
+
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
|
256 |
+
else:
|
257 |
+
q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
|
258 |
+
if NUM_DBLOCKS >= 2:
|
259 |
+
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN)
|
260 |
+
|
261 |
+
layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m
|
262 |
+
start_l = tl.load(layout_ptr).to(tl.int32)
|
263 |
+
end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32)
|
264 |
+
|
265 |
+
# loop over k, v and update accumulator
|
266 |
+
for col_idx_idx in range(start_l, end_l):
|
267 |
+
col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32)
|
268 |
+
start_n = col_idx * BLOCK_N
|
269 |
+
# -- compute qk ----
|
270 |
+
if EVEN_N_BLOCK:
|
271 |
+
k = tl.load(k_ptrs + start_n * stride_kn)
|
272 |
+
else:
|
273 |
+
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX)
|
274 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
275 |
+
qk += tl.dot(q, k)
|
276 |
+
|
277 |
+
if NUM_DBLOCKS >= 2:
|
278 |
+
if EVEN_N_BLOCK:
|
279 |
+
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd)
|
280 |
+
else:
|
281 |
+
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX)
|
282 |
+
qk += tl.dot(q2, k)
|
283 |
+
|
284 |
+
qk *= sm_scale
|
285 |
+
qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf'))
|
286 |
+
# -- compute m_ij, p, l_ij
|
287 |
+
m_ij = tl.max(qk, 1)
|
288 |
+
p = tl.exp(qk - m_ij[:, None])
|
289 |
+
l_ij = tl.sum(p, 1)
|
290 |
+
# -- update m_i and l_i
|
291 |
+
m_i_new = tl.maximum(m_i, m_ij)
|
292 |
+
alpha = tl.exp(m_i - m_i_new)
|
293 |
+
beta = tl.exp(m_ij - m_i_new)
|
294 |
+
l_i_new = alpha * l_i + beta * l_ij
|
295 |
+
# -- update output accumulator --
|
296 |
+
# scale p
|
297 |
+
p_scale = beta / l_i_new
|
298 |
+
p = p * p_scale[:, None]
|
299 |
+
# scale acc
|
300 |
+
acc_scale = l_i / l_i_new * alpha
|
301 |
+
# tl.store(t_ptrs, acc_scale)
|
302 |
+
# acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
303 |
+
acc = acc * acc_scale[:, None]
|
304 |
+
if NUM_DBLOCKS >= 2:
|
305 |
+
acc2 = acc2 * acc_scale[:, None]
|
306 |
+
p = p.to(Q.dtype.element_ty)
|
307 |
+
# update acc
|
308 |
+
if EVEN_N_BLOCK:
|
309 |
+
v = tl.load(v_ptrs + start_n * stride_vn)
|
310 |
+
else:
|
311 |
+
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX)
|
312 |
+
acc += tl.dot(p, v)
|
313 |
+
|
314 |
+
if NUM_DBLOCKS >= 2:
|
315 |
+
if EVEN_N_BLOCK:
|
316 |
+
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd)
|
317 |
+
else:
|
318 |
+
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX)
|
319 |
+
acc2 += tl.dot(p, v)
|
320 |
+
|
321 |
+
# update m_i and l_i
|
322 |
+
l_i = l_i_new
|
323 |
+
m_i = m_i_new
|
324 |
+
|
325 |
+
# rematerialize offsets to save registers
|
326 |
+
# start_m = tl.program_id(0)
|
327 |
+
# offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
328 |
+
# write back l and m
|
329 |
+
if not INFERENCE:
|
330 |
+
l_ptrs = L + off_hz * N_CTX + offs_m
|
331 |
+
m_ptrs = M + off_hz * N_CTX + offs_m
|
332 |
+
if EVEN_M_BLOCK:
|
333 |
+
tl.store(l_ptrs, l_i)
|
334 |
+
tl.store(m_ptrs, m_i)
|
335 |
+
else:
|
336 |
+
tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN)
|
337 |
+
tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN)
|
338 |
+
# initialize pointers to output
|
339 |
+
# offs_n = tl.arange(0, BLOCK_DMODEL)
|
340 |
+
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
|
341 |
+
out_ptrs = Out + off_o
|
342 |
+
tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
|
343 |
+
if NUM_DBLOCKS >= 2:
|
344 |
+
tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN)
|
345 |
+
|
346 |
+
|
347 |
+
## backward
|
348 |
+
@triton.heuristics(
|
349 |
+
{
|
350 |
+
'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
|
351 |
+
}
|
352 |
+
)
|
353 |
+
@triton.jit
|
354 |
+
def _bwd_preprocess(
|
355 |
+
Out, DO, L, # assume contiguous for Out, DO, L, NewDO, Delta layout.
|
356 |
+
NewDO, Delta,
|
357 |
+
N_CTX,
|
358 |
+
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
359 |
+
EVEN_M_BLOCK: tl.constexpr,
|
360 |
+
):
|
361 |
+
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
362 |
+
off_d = tl.arange(0, D_HEAD)
|
363 |
+
# load
|
364 |
+
if EVEN_M_BLOCK:
|
365 |
+
o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
|
366 |
+
do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
|
367 |
+
else:
|
368 |
+
o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
|
369 |
+
do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
|
370 |
+
denom = tl.load(L + off_m).to(tl.float32)
|
371 |
+
# compute
|
372 |
+
do = do / denom[:, None]
|
373 |
+
delta = tl.sum(o * do, axis=1)
|
374 |
+
# write-back
|
375 |
+
if EVEN_M_BLOCK:
|
376 |
+
tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do)
|
377 |
+
else:
|
378 |
+
tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do, mask=off_m[:, None] < N_CTX)
|
379 |
+
tl.store(Delta + off_m, delta)
|
380 |
+
|
381 |
+
|
382 |
+
# Does not suuport unequal seqlen(q) and seqlen(k)
|
383 |
+
@triton.heuristics(
|
384 |
+
{
|
385 |
+
'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
|
386 |
+
'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0,
|
387 |
+
}
|
388 |
+
)
|
389 |
+
@triton.jit
|
390 |
+
def _bwd_kernel(
|
391 |
+
Q, K, V, sm_scale,
|
392 |
+
layout_ccol_ptr,
|
393 |
+
layout_row_ptr,
|
394 |
+
layout_ccol_stride_h, layout_ccol_stride_m,
|
395 |
+
layout_row_stride_h, layout_row_stride_m,
|
396 |
+
Out, DO, # assume contigous: Out, Do, DQ, DK, DV, L, M, D, seq(q) == seq(k), with stride_oz, stride_oh, stride_om, stride_od,
|
397 |
+
DQ, DK, DV,
|
398 |
+
L, M,
|
399 |
+
D,
|
400 |
+
stride_qz, stride_qh, stride_qm, stride_qd,
|
401 |
+
stride_kz, stride_kh, stride_kn, stride_kd,
|
402 |
+
stride_vz, stride_vh, stride_vn, stride_vd,
|
403 |
+
stride_oz, stride_oh, stride_om, stride_od,
|
404 |
+
# stride_dz, stride_dh, stride_dm, stride_dd,
|
405 |
+
Z, H, N_CTX,
|
406 |
+
num_block,
|
407 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
408 |
+
BLOCK_N: tl.constexpr,
|
409 |
+
EVEN_M_BLOCK: tl.constexpr,
|
410 |
+
EVEN_N_BLOCK: tl.constexpr,
|
411 |
+
NUM_DBLOCKS: tl.constexpr,
|
412 |
+
):
|
413 |
+
start_n = tl.program_id(0)
|
414 |
+
off_hz = tl.program_id(1)
|
415 |
+
off_z = off_hz // H
|
416 |
+
off_h = off_hz % H
|
417 |
+
# offset pointers for batch/head
|
418 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
419 |
+
K += off_z * stride_kz + off_h * stride_kh
|
420 |
+
V += off_z * stride_vz + off_h * stride_vh
|
421 |
+
DO += off_z * stride_oz + off_h * stride_oh
|
422 |
+
DQ += off_z * stride_oz + off_h * stride_oh
|
423 |
+
DK += off_z * stride_oz + off_h * stride_oh
|
424 |
+
DV += off_z * stride_oz + off_h * stride_oh
|
425 |
+
# Look like this loop can be parallelled
|
426 |
+
# for start_n in range(0, num_block):
|
427 |
+
|
428 |
+
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
429 |
+
offs_m = tl.arange(0, BLOCK_M)
|
430 |
+
offs_d = tl.arange(0, BLOCK_DMODEL)
|
431 |
+
# initialize pointers to value-like data
|
432 |
+
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
|
433 |
+
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)
|
434 |
+
|
435 |
+
# pointer to row-wise quantities in value-like data
|
436 |
+
D_ptrs = D + off_hz * N_CTX
|
437 |
+
m_ptrs = M + off_hz * N_CTX
|
438 |
+
# initialize dv amd dk
|
439 |
+
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
440 |
+
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
441 |
+
# k and v stay in SRAM throughout
|
442 |
+
if EVEN_N_BLOCK:
|
443 |
+
k = tl.load(k_ptrs)
|
444 |
+
v = tl.load(v_ptrs)
|
445 |
+
else:
|
446 |
+
k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX)
|
447 |
+
v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX)
|
448 |
+
|
449 |
+
if NUM_DBLOCKS >= 2:
|
450 |
+
dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
451 |
+
dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
452 |
+
if EVEN_N_BLOCK:
|
453 |
+
k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd)
|
454 |
+
v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd)
|
455 |
+
else:
|
456 |
+
k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX)
|
457 |
+
v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX)
|
458 |
+
|
459 |
+
# loop over rows
|
460 |
+
|
461 |
+
layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m
|
462 |
+
start_l = tl.load(layout_ptr).to(tl.int32)
|
463 |
+
end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32)
|
464 |
+
|
465 |
+
for row_idx_idx in range(start_l, end_l):
|
466 |
+
row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32)
|
467 |
+
start_m = row_idx * BLOCK_M
|
468 |
+
|
469 |
+
# offs_qm = start_m + tl.arange(0, BLOCK_M)
|
470 |
+
offs_m_curr = start_m + offs_m
|
471 |
+
q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
|
472 |
+
do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
|
473 |
+
dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
|
474 |
+
|
475 |
+
# load q, k, v, do on-chip
|
476 |
+
if EVEN_M_BLOCK:
|
477 |
+
q = tl.load(q_ptrs)
|
478 |
+
else:
|
479 |
+
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX)
|
480 |
+
# re-compute p = softmax(qk, dim=-1).T
|
481 |
+
# NOTE: `do` is pre-divided by `l`; no normalization here
|
482 |
+
qk = tl.dot(q, tl.trans(k))
|
483 |
+
|
484 |
+
if NUM_DBLOCKS >= 2:
|
485 |
+
if EVEN_M_BLOCK:
|
486 |
+
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
|
487 |
+
else:
|
488 |
+
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX)
|
489 |
+
qk += tl.dot(q2, tl.trans(k2))
|
490 |
+
|
491 |
+
qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf'))
|
492 |
+
|
493 |
+
if EVEN_M_BLOCK:
|
494 |
+
m = tl.load(m_ptrs + offs_m_curr)
|
495 |
+
else:
|
496 |
+
m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
|
497 |
+
p = tl.exp(qk * sm_scale - m[:, None])
|
498 |
+
|
499 |
+
# compute dv
|
500 |
+
if EVEN_M_BLOCK:
|
501 |
+
do = tl.load(do_ptrs)
|
502 |
+
else:
|
503 |
+
do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX)
|
504 |
+
|
505 |
+
if NUM_DBLOCKS >= 2:
|
506 |
+
if EVEN_M_BLOCK:
|
507 |
+
do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od)
|
508 |
+
else:
|
509 |
+
do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX)
|
510 |
+
|
511 |
+
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
|
512 |
+
|
513 |
+
if NUM_DBLOCKS >= 2:
|
514 |
+
dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2)
|
515 |
+
|
516 |
+
# compute dp = dot(v, do)
|
517 |
+
if EVEN_M_BLOCK:
|
518 |
+
Di = tl.load(D_ptrs + offs_m_curr)
|
519 |
+
else:
|
520 |
+
Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
|
521 |
+
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
522 |
+
dp += tl.dot(do, tl.trans(v))
|
523 |
+
|
524 |
+
if NUM_DBLOCKS >= 2:
|
525 |
+
dp += tl.dot(do2, tl.trans(v2))
|
526 |
+
|
527 |
+
# compute ds = p * (dp - delta[:, None])
|
528 |
+
ds = p * dp * sm_scale
|
529 |
+
# compute dk = dot(ds.T, q)
|
530 |
+
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
|
531 |
+
if NUM_DBLOCKS >= 2:
|
532 |
+
dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2)
|
533 |
+
|
534 |
+
# # compute dq
|
535 |
+
dq = tl.dot(ds.to(Q.dtype.element_ty), k)
|
536 |
+
if EVEN_M_BLOCK:
|
537 |
+
tl.atomic_add(dq_ptrs, dq)
|
538 |
+
else:
|
539 |
+
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX)
|
540 |
+
|
541 |
+
if NUM_DBLOCKS >= 2:
|
542 |
+
dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2)
|
543 |
+
dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od
|
544 |
+
if EVEN_M_BLOCK:
|
545 |
+
tl.atomic_add(dq_ptrs2, dq2)
|
546 |
+
else:
|
547 |
+
tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX)
|
548 |
+
|
549 |
+
# write-back
|
550 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
|
551 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
|
552 |
+
if EVEN_N_BLOCK:
|
553 |
+
tl.store(dv_ptrs, dv)
|
554 |
+
tl.store(dk_ptrs, dk)
|
555 |
+
else:
|
556 |
+
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX)
|
557 |
+
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX)
|
558 |
+
|
559 |
+
if NUM_DBLOCKS >= 2:
|
560 |
+
dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od
|
561 |
+
dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od
|
562 |
+
if EVEN_N_BLOCK:
|
563 |
+
tl.store(dv_ptrs2, dv2)
|
564 |
+
tl.store(dk_ptrs2, dk2)
|
565 |
+
else:
|
566 |
+
tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX)
|
567 |
+
tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX)
|
568 |
+
|
569 |
+
|
570 |
+
|
571 |
+
def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None):
|
572 |
+
'''
|
573 |
+
:param q, k, v: [batch, n_heads, seq_len, model_dim]. len of q is allowed to be different than k/v.
|
574 |
+
:param layout_crow_indices, layout_col_indices: same as CSR.crow_indices, and CSR.col_indices used to preresent a sparse tensor.
|
575 |
+
Each element represent a block, i.e, all elements in a block to be attentdd, or not attended at all..
|
576 |
+
'''
|
577 |
+
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
|
578 |
+
assert k.shape[2] == v.shape[2]
|
579 |
+
o = out if out is not None else torch.empty_like(q).contiguous()
|
580 |
+
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
|
581 |
+
|
582 |
+
q_rounded_len = grid[0] * BLOCK_M
|
583 |
+
tmp = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
|
584 |
+
|
585 |
+
if inference is None:
|
586 |
+
inference = (not q.requires_grad) and (not k.requires_grad) and (not v.requires_grad)
|
587 |
+
|
588 |
+
if inference:
|
589 |
+
L, m = tmp, tmp # no need to use create new tensor
|
590 |
+
else:
|
591 |
+
L = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
|
592 |
+
m = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
|
593 |
+
|
594 |
+
if layout_col_indices.dim() == 1:
|
595 |
+
layout_crow_indices = layout_crow_indices[None].expand(q.shape[1] , -1)
|
596 |
+
layout_col_indices = layout_col_indices[None].expand(q.shape[1] , -1)
|
597 |
+
|
598 |
+
assert q.shape[-1] in [64, 128]
|
599 |
+
BLOCK_DMODEL = 64
|
600 |
+
|
601 |
+
if num_warps is None:
|
602 |
+
MIN_D = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL)
|
603 |
+
num_warps = max(1, 2 ** int(math.log2(MIN_D / 16)))
|
604 |
+
# print(f'> {BLOCK_M=}, {BLOCK_N=}, {BLOCK_DMODEL=}, {num_warps=}, {num_stages=}')
|
605 |
+
else:
|
606 |
+
assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be power of 2, but got {num_warps}.'''
|
607 |
+
|
608 |
+
## For debugging:
|
609 |
+
# print(f'>> {q.shape=}, {k.shape=}, {BLOCK_M=}, {BLOCK_N=}, {num_warps=}, {BLOCK_DMODEL=}, {q.stride()=}, {k.stride()=}')
|
610 |
+
# print(f'>> {layout_crow_indices=}\n{layout_col_indices=}\n {layout_crow_indices.stride()=}, {layout_crow_indices.stride()=}')
|
611 |
+
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
|
612 |
+
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
|
613 |
+
|
614 |
+
_fwd_kernel[grid](
|
615 |
+
q, k, v, sm_scale,
|
616 |
+
layout_crow_indices,
|
617 |
+
layout_col_indices,
|
618 |
+
layout_crow_indices.stride(0), layout_crow_indices.stride(1),
|
619 |
+
layout_col_indices.stride(0), layout_col_indices.stride(1),
|
620 |
+
tmp, L, m,
|
621 |
+
o,
|
622 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
623 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
624 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
625 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
626 |
+
q.shape[0], q.shape[1], k.shape[2],
|
627 |
+
k.shape[2] - q.shape[2],
|
628 |
+
q_rounded_len,
|
629 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
|
630 |
+
BLOCK_DMODEL=BLOCK_DMODEL,
|
631 |
+
EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
|
632 |
+
EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
|
633 |
+
INFERENCE=inference,
|
634 |
+
NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
|
635 |
+
num_warps=num_warps,
|
636 |
+
num_stages=num_stages,
|
637 |
+
)
|
638 |
+
if inference:
|
639 |
+
L, m = None, None
|
640 |
+
|
641 |
+
ctx.save_for_backward(q, k, v, o, L, m, layout_crow_indices, layout_col_indices)
|
642 |
+
ctx.BLOCK_M = BLOCK_M
|
643 |
+
ctx.BLOCK_N = BLOCK_N
|
644 |
+
ctx.BLOCK_DMODEL = BLOCK_DMODEL
|
645 |
+
# ctx.BLOCK = BLOCK
|
646 |
+
ctx.grid = grid
|
647 |
+
ctx.sm_scale = sm_scale
|
648 |
+
ctx.num_warps = num_warps
|
649 |
+
ctx.num_stages = num_stages
|
650 |
+
return o
|
651 |
+
|
652 |
+
|
653 |
+
def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None):
|
654 |
+
# q, k, v, o, l, m = ctx.saved_tensors
|
655 |
+
q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
|
656 |
+
|
657 |
+
## this following too slow to do online, so get it from inputs, which is cached.
|
658 |
+
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
|
659 |
+
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
|
660 |
+
|
661 |
+
if not do.is_contiguous():
|
662 |
+
do = do.contiguous()
|
663 |
+
## for debugging
|
664 |
+
# print(f'----> do is not contiguous: {do.stride()=}')
|
665 |
+
# raise ValueError(f'>>>> output grad is not contiguous: {do.stride()=}')
|
666 |
+
|
667 |
+
if not o.is_contiguous():
|
668 |
+
# TODO: currently only work with contiguous q/k/v.
|
669 |
+
raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.')
|
670 |
+
|
671 |
+
|
672 |
+
if layout_ccol_indices.dim() == 1:
|
673 |
+
layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1)
|
674 |
+
layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1)
|
675 |
+
|
676 |
+
# do = do.contiguous()
|
677 |
+
dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32)
|
678 |
+
dk = dk if dk is not None else torch.empty_like(k)
|
679 |
+
dv =dv if dv is not None else torch.empty_like(v)
|
680 |
+
do_scaled = torch.empty_like(do)
|
681 |
+
delta = torch.empty_like(l)
|
682 |
+
|
683 |
+
assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride()
|
684 |
+
|
685 |
+
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
686 |
+
o, do, l,
|
687 |
+
do_scaled, delta,
|
688 |
+
k.shape[2],
|
689 |
+
BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1],
|
690 |
+
)
|
691 |
+
|
692 |
+
grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1])
|
693 |
+
|
694 |
+
_bwd_kernel[grid](
|
695 |
+
q, k, v, ctx.sm_scale,
|
696 |
+
layout_ccol_indices,
|
697 |
+
layout_row_indices,
|
698 |
+
layout_ccol_indices.stride(0), layout_ccol_indices.stride(1),
|
699 |
+
layout_row_indices.stride(0), layout_row_indices.stride(1),
|
700 |
+
o, do_scaled,
|
701 |
+
dq, dk, dv,
|
702 |
+
l, m,
|
703 |
+
delta,
|
704 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
705 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
706 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
707 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
708 |
+
q.shape[0], q.shape[1], q.shape[2],
|
709 |
+
ctx.grid[0],
|
710 |
+
BLOCK_M=ctx.BLOCK_M,
|
711 |
+
BLOCK_N=ctx.BLOCK_N,
|
712 |
+
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
|
713 |
+
NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL,
|
714 |
+
num_warps=ctx.num_warps,
|
715 |
+
num_stages=1,
|
716 |
+
)
|
717 |
+
return dq, dk, dv, None, None, None
|
718 |
+
|
719 |
+
|
720 |
+
class _sparse_attention(torch.autograd.Function):
|
721 |
+
|
722 |
+
@staticmethod
|
723 |
+
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
|
724 |
+
BLOCK = 128
|
725 |
+
# shape constraints
|
726 |
+
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK, BLOCK)
|
727 |
+
|
728 |
+
@staticmethod
|
729 |
+
def backward(ctx, do):
|
730 |
+
# q, k, v, o, l, m = ctx.saved_tensors
|
731 |
+
q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
|
732 |
+
# TODO: the following is very inefficient.
|
733 |
+
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
|
734 |
+
layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
|
735 |
+
return _backward(ctx, do, layout_ccol_indices, layout_row_indices)
|
736 |
+
|
737 |
+
|
738 |
+
|
739 |
+
# suppressed
|
740 |
+
class _sparse_attention_inference(_sparse_attention):
|
741 |
+
# TODO: does not work now, as BLOCK_M cannot be <1, as shape for tl.dot cannot be smaller than 16.
|
742 |
+
@staticmethod
|
743 |
+
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
|
744 |
+
BLOCK = 128
|
745 |
+
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, 1, BLOCK)
|
746 |
+
|
747 |
+
|
748 |
+
|
749 |
+
def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs):
|
750 |
+
class _sparse_attention_config(_sparse_attention):
|
751 |
+
@staticmethod
|
752 |
+
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
|
753 |
+
# shape constraints
|
754 |
+
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
|
755 |
+
**kwargs
|
756 |
+
)
|
757 |
+
return _sparse_attention_config.apply
|
758 |
+
|
759 |
+
|
760 |
+
@lru_cache(maxsize=8)
|
761 |
+
def get_local_strided_sparse_attention_op(
|
762 |
+
n_heads: int,
|
763 |
+
max_seq_len:int,
|
764 |
+
sparse_block_size: int=128,
|
765 |
+
local_blocks: int=4,
|
766 |
+
vert_stride: int=4,
|
767 |
+
homo_head: bool=False,
|
768 |
+
dtype=torch.bfloat16,
|
769 |
+
device='cuda',
|
770 |
+
active_head_range=None,
|
771 |
+
verbose=True,
|
772 |
+
**kwargs):
|
773 |
+
'''
|
774 |
+
:param n_heads: total number of attention heads (regardless of tensor/model parallel)
|
775 |
+
:param max_seq_len: max sequence length. Need to be bigger or equal to the length of sequences.
|
776 |
+
:param sparse_block_size: sparse block size. Default to 128
|
777 |
+
:param local_blocks: number of nearest block to attend to. Default to 4, i.e., attention to previous 4xblock_size tokens.
|
778 |
+
:param vert_stride: Default to 4. Meaning
|
779 |
+
:param homo_head: if all head shared the same pattern.
|
780 |
+
:param active_head_range: tuple of start & end of the heads, e..g, (8, 16). Default to use all heads.
|
781 |
+
Mainly for tensor/model parallelization where heads are splitted to different GPUs.
|
782 |
+
'''
|
783 |
+
|
784 |
+
if verbose:
|
785 |
+
print((f'> new block_sparse_attn op constructed with config: '
|
786 |
+
f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, '
|
787 |
+
f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}'))
|
788 |
+
# assert math.log2(max_seq_len) % 2 == 0, f"max_seq_len should be power of 2 to be more efficient"
|
789 |
+
_, block_sparse_pattern, _ = _get_sparse_attn_mask(n_heads, max_seq_len, max_seq_len, dtype, device,
|
790 |
+
BLOCK=sparse_block_size, local_blocks=local_blocks,
|
791 |
+
vert_stride=vert_stride, homo_head=homo_head,
|
792 |
+
return_dense=False)
|
793 |
+
if (not homo_head) and (active_head_range is not None):
|
794 |
+
assert isinstance(active_head_range, tuple)
|
795 |
+
assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.'
|
796 |
+
h_start, h_end = active_head_range
|
797 |
+
block_sparse_pattern = block_sparse_pattern[h_start:h_end]
|
798 |
+
# print(block_sparse_pattern)
|
799 |
+
return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs)
|
800 |
+
|
801 |
+
|
802 |
+
def get_sparse_attn_op(
|
803 |
+
sparse_pattern: torch.tensor,
|
804 |
+
sparse_block_size: int=128,
|
805 |
+
kernel_block_size=128,
|
806 |
+
qkv_format='q,k,v',
|
807 |
+
**kwargs):
|
808 |
+
'''
|
809 |
+
Ccreate a block-sparse op with fixed layout. This is to avoid the need to of create CSR layout and convert it to CSC layout everytime,
|
810 |
+
which is very inefficient (use python loops on CPU. PyTorch 1.13 supports CSR->CSC, may help.)
|
811 |
+
|
812 |
+
:param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.
|
813 |
+
This tensor should have lower-triangular matrices on the last 2 dimensions for causal attention
|
814 |
+
:param sparse_block_size: sparse block size. Default to 128
|
815 |
+
:param kernel_block_size: the tile/block size to launch a triton instance. Default to None, i.e., same as `sparse_block_size`
|
816 |
+
:param qkv_format: Choices=['q,k,v', 'q, kv', 'qkv'], i.e., separated q,k,v, or kv packed, or qkv packed. Currently, only 'q,k,v' is supported.
|
817 |
+
|
818 |
+
:param kwargs: keyward arguments passed to `_forward`
|
819 |
+
'''
|
820 |
+
# assert qkv_format in ('q,k,v', 'q, kv', 'qkv') # to save from running `concat` at forward/backward
|
821 |
+
|
822 |
+
assert qkv_format == 'q,k,v'
|
823 |
+
|
824 |
+
if kernel_block_size is None:
|
825 |
+
kernel_block_size = sparse_block_size
|
826 |
+
else:
|
827 |
+
assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}."
|
828 |
+
assert kernel_block_size >=16 and math.log2(kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {kernel_block_size} is given"
|
829 |
+
|
830 |
+
|
831 |
+
# print(f'>> {sparse_pattern.shape=}')
|
832 |
+
# print(f'{sparse_pattern=}')
|
833 |
+
if sparse_block_size // kernel_block_size > 1:
|
834 |
+
_mul = sparse_block_size // kernel_block_size
|
835 |
+
# need to consider if block_m and block_n are different
|
836 |
+
sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul))
|
837 |
+
num_sparse_blocks = sparse_pattern.size(-1)
|
838 |
+
block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
|
839 |
+
sparse_pattern *= block_causal_mask.type_as(sparse_pattern)
|
840 |
+
# print(f'>> after: {sparse_pattern.shape=}')
|
841 |
+
# print(f'{sparse_pattern=}')
|
842 |
+
|
843 |
+
BLOCK_N = kernel_block_size
|
844 |
+
NUM_BLOCK = sparse_pattern.size(-1)
|
845 |
+
MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK
|
846 |
+
|
847 |
+
grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern)
|
848 |
+
# sparse csc layout for backward
|
849 |
+
grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern)
|
850 |
+
|
851 |
+
|
852 |
+
# cache GPU backward layout. limit the size to avoid OOM as time goes.
|
853 |
+
# For inference, one only needs to cache one block as sequence length always increases
|
854 |
+
# Therefore, this cache needs to be reconstructed per every `block_size`-steps.
|
855 |
+
# For training/finetune, set to 8 to increase cache hit.
|
856 |
+
# Given an input, the block_len will be the same for all layers, so cache is very helpful.
|
857 |
+
|
858 |
+
max_cache_size = 1 if kwargs.get('inference', False) else 8
|
859 |
+
|
860 |
+
@lru_cache(maxsize=max_cache_size)
|
861 |
+
def get_backward_layout_by_block_len(block_len):
|
862 |
+
assert block_len <= NUM_BLOCK
|
863 |
+
if block_len == NUM_BLOCK:
|
864 |
+
return (grand_layout_ccol_indices, grand_layout_row_indices)
|
865 |
+
return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len])
|
866 |
+
|
867 |
+
# for debugging
|
868 |
+
# if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
|
869 |
+
# print(f'> {sparse_pattern.cpu().tolist()=}')
|
870 |
+
# print('----')
|
871 |
+
# print(f'> {grand_layout_crow_indices.cpu().tolist()=}\n{grand_layout_col_indices.cpu().tolist()=}')
|
872 |
+
|
873 |
+
|
874 |
+
# q, k, v separated
|
875 |
+
class _q_k_v_sparse_attention(torch.autograd.Function):
|
876 |
+
@staticmethod
|
877 |
+
def forward(ctx, q, k, v, sm_scale):
|
878 |
+
# assert q.shape[2] == 1 or q.shape[2] == k.shape[2]
|
879 |
+
# shape constraints
|
880 |
+
MIN_BLOCK_SIZE = 16
|
881 |
+
assert BLOCK_N >= MIN_BLOCK_SIZE
|
882 |
+
BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N # BLOCK_M has to be power of 2
|
883 |
+
|
884 |
+
# this following code only works for causal attention
|
885 |
+
K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size)
|
886 |
+
# Q_START_BLOCKS = K_BLOCKS - 1 if q.shape[2] == 1 else 0
|
887 |
+
Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N)
|
888 |
+
# print(Q_START_BLOCKS, K_BLOCKS)
|
889 |
+
|
890 |
+
layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1]
|
891 |
+
layout_col_indices = grand_layout_col_indices
|
892 |
+
# print(BLOCK_M, BLOCK_N, Q_START_BLOCKS, K_BLOCKS+1, layout_crow_indices, layout_col_indices)
|
893 |
+
|
894 |
+
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
|
895 |
+
**kwargs
|
896 |
+
)
|
897 |
+
@staticmethod
|
898 |
+
def backward(ctx, do):
|
899 |
+
q, k = ctx.saved_tensors[:2]
|
900 |
+
assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.'
|
901 |
+
# assume q, k have same length
|
902 |
+
block_len = triton.cdiv(do.shape[2], kernel_block_size)
|
903 |
+
backward_layout = get_backward_layout_by_block_len(block_len)
|
904 |
+
return _backward(ctx, do, *backward_layout)[:4]
|
905 |
+
|
906 |
+
|
907 |
+
def _q_k_v_sparse_attention_fn(*args):
|
908 |
+
return _q_k_v_sparse_attention.apply(*args)
|
909 |
+
|
910 |
+
_q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern
|
911 |
+
_q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices
|
912 |
+
_q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices
|
913 |
+
_q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices
|
914 |
+
_q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices
|
915 |
+
|
916 |
+
return _q_k_v_sparse_attention_fn
|
917 |
+
|
918 |
+
###########################################################
|
919 |
+
###########################################################
|
920 |
+
|
921 |
+
###########################################################
|
922 |
+
################ Inference Kernels ########################
|
923 |
+
###########################################################
|
924 |
+
|
925 |
+
def blocksparse_flash_attn_padded_fwd(
|
926 |
+
q, k, v, # (batch, tokens, n_heads, head_size)
|
927 |
+
sm_scale,
|
928 |
+
sparse_layout,
|
929 |
+
*,
|
930 |
+
left_paddings = None,
|
931 |
+
seqlens = None,
|
932 |
+
block_size = 64,
|
933 |
+
max_seqlen = None
|
934 |
+
):
|
935 |
+
'''
|
936 |
+
q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size)
|
937 |
+
left_paddings: (batch, ), number of left paddings for each sample.
|
938 |
+
seqlens: can be used to specify right padding. No need to specify if left_paddings is used.
|
939 |
+
'''
|
940 |
+
batches, q_len, n_heads, head_size = q.shape
|
941 |
+
_, k_len, n_kv_heads, _ = k.shape
|
942 |
+
|
943 |
+
|
944 |
+
assert q.dim() == k.dim() == v.dim() == 4
|
945 |
+
assert q.size(2) % k.size(2) == 0
|
946 |
+
assert q.size(0) == k.size(0) and q.size(3) == k.size(3)
|
947 |
+
assert k.shape == v.shape # TODO: allow diff head_size for k, v
|
948 |
+
assert q_len == 1 or q_len == k_len, \
|
949 |
+
f'q length can only 1 for decoding for same as k length for prefilling.'
|
950 |
+
|
951 |
+
q_k_ratio = q.size(2) // k.size(2)
|
952 |
+
|
953 |
+
if max_seqlen:
|
954 |
+
assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.'
|
955 |
+
|
956 |
+
# paddings always has zero output, a little slower than using empty
|
957 |
+
out = q.new_zeros(q.shape)
|
958 |
+
|
959 |
+
layout_crow_indices, layout_col_indices = sparse_layout
|
960 |
+
block_d = triton.next_power_of_2(head_size)
|
961 |
+
|
962 |
+
if left_paddings is not None:
|
963 |
+
assert left_paddings.shape == (batches,)
|
964 |
+
k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous()
|
965 |
+
else:
|
966 |
+
k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device)
|
967 |
+
|
968 |
+
if seqlens is not None:
|
969 |
+
k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts)
|
970 |
+
assert k_batch_ends.max() <= k_len, f'seqlens (+left_paddings if any) exceeds seqlen.'
|
971 |
+
else:
|
972 |
+
k_batch_ends = torch.zeros_like(k_batch_starts) + k_len
|
973 |
+
|
974 |
+
if q_len == 1:
|
975 |
+
q_batch_starts = torch.zeros_like(k_batch_starts)
|
976 |
+
q_batch_ends = q_batch_starts + 1
|
977 |
+
else:
|
978 |
+
q_batch_starts = k_batch_starts
|
979 |
+
q_batch_ends = k_batch_ends
|
980 |
+
|
981 |
+
# switch to use cpu to avoid too many kernel lauch when iterate over
|
982 |
+
q_lens = (q_batch_ends - q_batch_starts).cpu()
|
983 |
+
n_blocks = (q_lens + block_size - 1) // block_size
|
984 |
+
|
985 |
+
q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
|
986 |
+
dtype=q_batch_starts.dtype,
|
987 |
+
device=q_batch_starts.device)
|
988 |
+
q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
|
989 |
+
dtype=q_batch_starts.dtype,
|
990 |
+
device=q_batch_starts.device)
|
991 |
+
|
992 |
+
grid = (len(q_start_sids), n_heads)
|
993 |
+
|
994 |
+
_fwd_kernel_batch_inference[grid](
|
995 |
+
q, k, v, out,
|
996 |
+
sm_scale,
|
997 |
+
q_batch_starts,
|
998 |
+
q_batch_ends,
|
999 |
+
k_batch_starts,
|
1000 |
+
k_batch_ends,
|
1001 |
+
q_batch_ids,
|
1002 |
+
q_start_sids,
|
1003 |
+
|
1004 |
+
*q.stride(),
|
1005 |
+
*k.stride(),
|
1006 |
+
*v.stride(),
|
1007 |
+
*out.stride(),
|
1008 |
+
|
1009 |
+
layout_crow_indices,
|
1010 |
+
layout_col_indices,
|
1011 |
+
*layout_crow_indices.stride(),
|
1012 |
+
*layout_col_indices.stride(),
|
1013 |
+
|
1014 |
+
q_k_ratio,
|
1015 |
+
HAS_BATCH_DIM = True,
|
1016 |
+
D_HEAD = head_size,
|
1017 |
+
BLOCK_M = block_size,
|
1018 |
+
BLOCK_N = block_size,
|
1019 |
+
BLOCK_D = block_d,
|
1020 |
+
BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
|
1021 |
+
EVEN_D = block_d == head_size,
|
1022 |
+
num_warps = 1 if q_len == 1 else 4,
|
1023 |
+
num_stages = 3
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
return out
|
1027 |
+
|
1028 |
+
|
1029 |
+
def blocksparse_flash_attn_varlen_fwd(
|
1030 |
+
q, k, v, # (#tokens, n_heads, head_size)
|
1031 |
+
cu_seqlens_k,
|
1032 |
+
cu_seqlens_q,
|
1033 |
+
sm_scale,
|
1034 |
+
sparse_layout,
|
1035 |
+
*,
|
1036 |
+
block_size=64,
|
1037 |
+
max_seqlen = None
|
1038 |
+
):
|
1039 |
+
# split q to blocks
|
1040 |
+
_, n_heads, head_size = q.shape
|
1041 |
+
batch_size = cu_seqlens_k.size(0) - 1
|
1042 |
+
|
1043 |
+
|
1044 |
+
# print(f'> {q.shape=}, {k.shape=}')
|
1045 |
+
assert q.dim() == k.dim() == v.dim() == 3
|
1046 |
+
assert q.size(1) % k.size(1) == 0
|
1047 |
+
assert q.size(2) == k.size(2)
|
1048 |
+
assert k.shape == v.shape # TODO: allow diff head_size for k, v
|
1049 |
+
assert cu_seqlens_k.dim() == 1
|
1050 |
+
|
1051 |
+
q_k_ratio = q.size(1) // k.size(1)
|
1052 |
+
|
1053 |
+
if cu_seqlens_q is None:
|
1054 |
+
if q.size(0) == batch_size: # decoding only
|
1055 |
+
cu_seqlens_q = torch.arange(0, batch_size + 1,
|
1056 |
+
dtype=cu_seqlens_k.dtype,
|
1057 |
+
device=cu_seqlens_k.device)
|
1058 |
+
elif q.size(0) == k.size(0):
|
1059 |
+
cu_seqlens_q = cu_seqlens_k
|
1060 |
+
else:
|
1061 |
+
raise ValueError('cu_seqlens_q must be specified if it is mix of prefilling and decoding.')
|
1062 |
+
else:
|
1063 |
+
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
|
1064 |
+
|
1065 |
+
# switch to use cpu to avoid too many kernel lauch when iterate over
|
1066 |
+
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
|
1067 |
+
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
|
1068 |
+
|
1069 |
+
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \
|
1070 |
+
'length of q should either be 1 (decoding) or same as k (prefilling).'
|
1071 |
+
|
1072 |
+
if max_seqlen:
|
1073 |
+
assert k_lens.max() <= max_seqlen
|
1074 |
+
|
1075 |
+
n_blocks = (q_lens + block_size - 1) // block_size
|
1076 |
+
|
1077 |
+
q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
|
1078 |
+
dtype=cu_seqlens_q.dtype,
|
1079 |
+
device=cu_seqlens_q.device)
|
1080 |
+
q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
|
1081 |
+
dtype=cu_seqlens_q.dtype,
|
1082 |
+
device=cu_seqlens_q.device)
|
1083 |
+
|
1084 |
+
|
1085 |
+
out = q.new_empty(q.shape)
|
1086 |
+
cu_seqlens_q = cu_seqlens_q.contiguous()
|
1087 |
+
cu_seqlens_k = cu_seqlens_k.contiguous()
|
1088 |
+
|
1089 |
+
layout_crow_indices, layout_col_indices = sparse_layout
|
1090 |
+
block_d = triton.next_power_of_2(head_size)
|
1091 |
+
|
1092 |
+
decoding_only = (q_lens == 1).all()
|
1093 |
+
|
1094 |
+
grid = (len(q_start_sids), n_heads)
|
1095 |
+
|
1096 |
+
_fwd_kernel_batch_inference[grid](
|
1097 |
+
q, k, v, out,
|
1098 |
+
sm_scale,
|
1099 |
+
cu_seqlens_q[:-1],
|
1100 |
+
cu_seqlens_q[1:],
|
1101 |
+
cu_seqlens_k[:-1],
|
1102 |
+
cu_seqlens_k[1:],
|
1103 |
+
q_batch_ids,
|
1104 |
+
q_start_sids,
|
1105 |
+
|
1106 |
+
0, *q.stride(),
|
1107 |
+
0, *k.stride(),
|
1108 |
+
0, *v.stride(),
|
1109 |
+
0, *out.stride(),
|
1110 |
+
|
1111 |
+
layout_crow_indices,
|
1112 |
+
layout_col_indices,
|
1113 |
+
*layout_crow_indices.stride(),
|
1114 |
+
*layout_col_indices.stride(),
|
1115 |
+
|
1116 |
+
q_k_ratio,
|
1117 |
+
HAS_BATCH_DIM = False,
|
1118 |
+
D_HEAD = head_size,
|
1119 |
+
BLOCK_M = block_size,
|
1120 |
+
BLOCK_N = block_size,
|
1121 |
+
BLOCK_D = block_d,
|
1122 |
+
BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
|
1123 |
+
EVEN_D = block_d == head_size,
|
1124 |
+
num_warps = 1 if decoding_only else 4,
|
1125 |
+
num_stages = 3
|
1126 |
+
)
|
1127 |
+
|
1128 |
+
return out
|
1129 |
+
|
1130 |
+
|
1131 |
+
@triton.jit
|
1132 |
+
def _fwd_kernel_inner(
|
1133 |
+
acc, l_i, m_i,
|
1134 |
+
q, Q,
|
1135 |
+
k_block_col_idx,
|
1136 |
+
layout_col_ptr,
|
1137 |
+
layout_col_stride_h, layout_col_stride_m,
|
1138 |
+
k_ptrs,
|
1139 |
+
v_ptrs,
|
1140 |
+
off_h, offs_m, offs_n, offs_d,
|
1141 |
+
stride_kt, stride_vt,
|
1142 |
+
sm_scale,
|
1143 |
+
k_seqlen,
|
1144 |
+
past_len,
|
1145 |
+
LAST_K_BLOCK: tl.constexpr,
|
1146 |
+
BLOCK_M_LOADING: tl.constexpr,
|
1147 |
+
BLOCK_N: tl.constexpr,
|
1148 |
+
D_HEAD: tl.constexpr,
|
1149 |
+
EVEN_D: tl.constexpr,
|
1150 |
+
M_LT_N: tl.constexpr
|
1151 |
+
):
|
1152 |
+
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32)
|
1153 |
+
start_n = k_block_id * BLOCK_N
|
1154 |
+
# -- compute qk ----
|
1155 |
+
if LAST_K_BLOCK:
|
1156 |
+
if EVEN_D:
|
1157 |
+
k = tl.load(k_ptrs + start_n * stride_kt,
|
1158 |
+
mask=offs_n[None, :] + start_n < k_seqlen)
|
1159 |
+
else:
|
1160 |
+
# mask = mask & (offs_d[:, ])
|
1161 |
+
k = tl.load(k_ptrs + start_n * stride_kt,
|
1162 |
+
mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD))
|
1163 |
+
else:
|
1164 |
+
if EVEN_D:
|
1165 |
+
k = tl.load(k_ptrs + start_n * stride_kt)
|
1166 |
+
else:
|
1167 |
+
k = tl.load(k_ptrs + start_n * stride_kt,
|
1168 |
+
mask=offs_d[:, None] < D_HEAD)
|
1169 |
+
|
1170 |
+
|
1171 |
+
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
1172 |
+
qk += tl.dot(q, k)
|
1173 |
+
|
1174 |
+
qk *= sm_scale
|
1175 |
+
|
1176 |
+
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
|
1177 |
+
if LAST_K_BLOCK | M_LT_N:
|
1178 |
+
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))
|
1179 |
+
|
1180 |
+
# -- compute m_ij, p, l_ij
|
1181 |
+
m_ij = tl.max(qk, 1)
|
1182 |
+
p = tl.exp(qk - m_ij[:, None])
|
1183 |
+
|
1184 |
+
l_ij = tl.sum(p, 1)
|
1185 |
+
# -- update m_i and l_i
|
1186 |
+
m_i_new = tl.maximum(m_i, m_ij)
|
1187 |
+
alpha = tl.exp(m_i - m_i_new)
|
1188 |
+
beta = tl.exp(m_ij - m_i_new)
|
1189 |
+
l_i_new = alpha * l_i + beta * l_ij
|
1190 |
+
# -- update output accumulator --
|
1191 |
+
# scale p
|
1192 |
+
p_scale = beta / l_i_new
|
1193 |
+
p = p * p_scale[:, None]
|
1194 |
+
# scale acc
|
1195 |
+
acc_scale = l_i / l_i_new * alpha
|
1196 |
+
acc = acc * acc_scale[:, None]
|
1197 |
+
|
1198 |
+
p = p.to(Q.dtype.element_ty)
|
1199 |
+
# update acc
|
1200 |
+
if LAST_K_BLOCK:
|
1201 |
+
if EVEN_D:
|
1202 |
+
v = tl.load(v_ptrs + start_n * stride_vt,
|
1203 |
+
mask=offs_n[:, None] + start_n < k_seqlen)
|
1204 |
+
else:
|
1205 |
+
v = tl.load(v_ptrs + start_n * stride_vt,
|
1206 |
+
mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD))
|
1207 |
+
else:
|
1208 |
+
if EVEN_D:
|
1209 |
+
v = tl.load(v_ptrs + start_n * stride_vt)
|
1210 |
+
else:
|
1211 |
+
v = tl.load(v_ptrs + start_n * stride_vt,
|
1212 |
+
mask=offs_d[None, :] < D_HEAD)
|
1213 |
+
|
1214 |
+
acc += tl.dot(p, v)
|
1215 |
+
# update m_i and l_i
|
1216 |
+
l_i = l_i_new
|
1217 |
+
m_i = m_i_new
|
1218 |
+
return acc, l_i, m_i
|
1219 |
+
|
1220 |
+
|
1221 |
+
@triton.heuristics(
|
1222 |
+
{
|
1223 |
+
'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'],
|
1224 |
+
}
|
1225 |
+
)
|
1226 |
+
@triton.jit
|
1227 |
+
def _fwd_kernel_batch_inference(
|
1228 |
+
Q, K, V, Out,
|
1229 |
+
|
1230 |
+
sm_scale,
|
1231 |
+
q_batch_starts,
|
1232 |
+
q_batch_ends,
|
1233 |
+
k_batch_starts,
|
1234 |
+
k_batch_ends,
|
1235 |
+
q_batch_ids,
|
1236 |
+
q_start_sids,
|
1237 |
+
|
1238 |
+
stride_qb, stride_qt, stride_qh, stride_qd,
|
1239 |
+
stride_kb, stride_kt, stride_kh, stride_kd,
|
1240 |
+
stride_vb, stride_vt, stride_vh, stride_vd,
|
1241 |
+
stride_ob, stride_ot, stride_oh, stride_od,
|
1242 |
+
|
1243 |
+
layout_crow_ptr,
|
1244 |
+
layout_col_ptr,
|
1245 |
+
layout_crow_stride_h, layout_crow_stride_m,
|
1246 |
+
layout_col_stride_h, layout_col_stride_m,
|
1247 |
+
|
1248 |
+
q_k_ratio,
|
1249 |
+
|
1250 |
+
HAS_BATCH_DIM: tl.constexpr,
|
1251 |
+
D_HEAD: tl.constexpr,
|
1252 |
+
BLOCK_M: tl.constexpr,
|
1253 |
+
BLOCK_N: tl.constexpr,
|
1254 |
+
BLOCK_D: tl.constexpr,
|
1255 |
+
BLOCK_M_LOADING: tl.constexpr,
|
1256 |
+
EVEN_D: tl.constexpr,
|
1257 |
+
M_LT_N: tl.constexpr
|
1258 |
+
):
|
1259 |
+
'''
|
1260 |
+
NOTATION:
|
1261 |
+
pid: position id
|
1262 |
+
sid: storage id
|
1263 |
+
sbid: storage block id
|
1264 |
+
pbid: position block id
|
1265 |
+
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
|
1266 |
+
|
1267 |
+
q and blocks in KV needs to be contiguous
|
1268 |
+
|
1269 |
+
Arguments:
|
1270 |
+
kv_seq_lens: for compute past_len
|
1271 |
+
kv_storage_offsets: similar to block_tables in vllm, except it is dynamic.
|
1272 |
+
TODO: fix this
|
1273 |
+
|
1274 |
+
TODO:
|
1275 |
+
Optimize grouped-attn
|
1276 |
+
|
1277 |
+
CUDA graph support issue
|
1278 |
+
1. grid is dynamic: vllm set up multiple cuda graph in decoding phase, with diff max token size (16, 32, ...)
|
1279 |
+
since we mix prompt and decoing phase here, it can be more complex.
|
1280 |
+
need to set up diff cuda-graph for diff (off_zm, off_z)
|
1281 |
+
|
1282 |
+
# indeed, q_batch_ids can be padded to maximum number of grid[0], i.e., assume all decoding
|
1283 |
+
therefore, cu_seqlens_q, kv_seq_lens
|
1284 |
+
|
1285 |
+
'''
|
1286 |
+
off_zm = tl.program_id(0)
|
1287 |
+
off_h = tl.program_id(1)
|
1288 |
+
|
1289 |
+
off_h_for_kv = off_h // q_k_ratio
|
1290 |
+
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
|
1291 |
+
q_start_sid = tl.load(q_start_sids + off_zm)
|
1292 |
+
start_m = q_start_sid // BLOCK_M
|
1293 |
+
|
1294 |
+
if HAS_BATCH_DIM:
|
1295 |
+
Q += off_z * stride_qb
|
1296 |
+
K += off_z * stride_kb
|
1297 |
+
V += off_z * stride_vb
|
1298 |
+
Out += off_z * stride_ob
|
1299 |
+
|
1300 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
|
1301 |
+
offs_n = tl.arange(0, BLOCK_N)
|
1302 |
+
offs_d = tl.arange(0, BLOCK_D)
|
1303 |
+
|
1304 |
+
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
|
1305 |
+
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
|
1306 |
+
|
1307 |
+
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
|
1308 |
+
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
|
1309 |
+
|
1310 |
+
past_len = k_seqlen - q_seqlen
|
1311 |
+
|
1312 |
+
Q += q_cu_start * stride_qt + off_h * stride_qh
|
1313 |
+
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
|
1314 |
+
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
|
1315 |
+
Out += q_cu_start * stride_ot + off_h * stride_oh
|
1316 |
+
|
1317 |
+
q_pbid = (past_len + q_start_sid) // BLOCK_M
|
1318 |
+
|
1319 |
+
if EVEN_D:
|
1320 |
+
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
1321 |
+
mask=offs_m[:, None] < q_seqlen)
|
1322 |
+
else:
|
1323 |
+
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
1324 |
+
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
1325 |
+
other=0)
|
1326 |
+
|
1327 |
+
sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m
|
1328 |
+
|
1329 |
+
# TODO: load at once, supported in new Triton
|
1330 |
+
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
|
1331 |
+
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
|
1332 |
+
|
1333 |
+
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf')
|
1334 |
+
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
|
1335 |
+
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
|
1336 |
+
|
1337 |
+
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
|
1338 |
+
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
|
1339 |
+
|
1340 |
+
for k_block_col_idx in range(k_block_start, k_block_end - 1):
|
1341 |
+
acc, l_i, m_i = _fwd_kernel_inner(
|
1342 |
+
acc, l_i, m_i,
|
1343 |
+
q, Q,
|
1344 |
+
k_block_col_idx,
|
1345 |
+
layout_col_ptr,
|
1346 |
+
layout_col_stride_h, layout_col_stride_m,
|
1347 |
+
k_ptrs,
|
1348 |
+
v_ptrs,
|
1349 |
+
off_h, offs_m, offs_n, offs_d,
|
1350 |
+
stride_kt, stride_vt,
|
1351 |
+
sm_scale,
|
1352 |
+
k_seqlen,
|
1353 |
+
past_len,
|
1354 |
+
False,
|
1355 |
+
BLOCK_M_LOADING,
|
1356 |
+
BLOCK_N,
|
1357 |
+
D_HEAD,
|
1358 |
+
EVEN_D,
|
1359 |
+
M_LT_N
|
1360 |
+
)
|
1361 |
+
|
1362 |
+
acc, l_i, m_i = _fwd_kernel_inner(
|
1363 |
+
acc, l_i, m_i,
|
1364 |
+
q, Q,
|
1365 |
+
k_block_end - 1,
|
1366 |
+
layout_col_ptr,
|
1367 |
+
layout_col_stride_h, layout_col_stride_m,
|
1368 |
+
k_ptrs,
|
1369 |
+
v_ptrs,
|
1370 |
+
off_h, offs_m, offs_n, offs_d,
|
1371 |
+
stride_kt, stride_vt,
|
1372 |
+
sm_scale,
|
1373 |
+
k_seqlen,
|
1374 |
+
past_len,
|
1375 |
+
True,
|
1376 |
+
BLOCK_M_LOADING,
|
1377 |
+
BLOCK_N,
|
1378 |
+
D_HEAD,
|
1379 |
+
EVEN_D,
|
1380 |
+
M_LT_N
|
1381 |
+
)
|
1382 |
+
|
1383 |
+
# write output
|
1384 |
+
if EVEN_D:
|
1385 |
+
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
|
1386 |
+
mask=offs_m[:, None] < q_seqlen)
|
1387 |
+
else:
|
1388 |
+
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
|
1389 |
+
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD))
|
1390 |
+
|
1391 |
+
|
1392 |
+
###########################################################
|
1393 |
+
###########################################################
|
1394 |
+
|
1395 |
+
###########################################################
|
1396 |
+
################## Testing Utilities ######################
|
1397 |
+
###########################################################
|
1398 |
+
|
1399 |
+
|
1400 |
+
def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None):
|
1401 |
+
'''
|
1402 |
+
q, k, v: shape=(batch, n_heads, seq, dim)
|
1403 |
+
'''
|
1404 |
+
# for verification
|
1405 |
+
if sm_scale is None:
|
1406 |
+
sm_scale = math.sqrt(float(q.size(-1)))
|
1407 |
+
|
1408 |
+
if block_attn_mask is not None:
|
1409 |
+
assert attn_mask is None
|
1410 |
+
outs = []
|
1411 |
+
for s in range(0, q.size(2), block_size):
|
1412 |
+
e = min(s + block_size, q.size(2))
|
1413 |
+
q_block = q[:, :, s:e]
|
1414 |
+
attn = torch.einsum('bhmd,bhnd->bhmn', q_block, k[:, :, :e]).float() * sm_scale
|
1415 |
+
mask = block_attn_mask[..., s // block_size, : (s // block_size + 1)]
|
1416 |
+
mask = torch.kron(mask, torch.ones(block_size, block_size, device=mask.device))
|
1417 |
+
mask[..., :, s:].masked_fill_(torch.arange(0, block_size)[:, None] <= torch.arange(0, block_size)[None, :], 0)
|
1418 |
+
attn = attn.masked_fill((1 - mask).bool(), float('-inf'))
|
1419 |
+
attn = attn.softmax(-1)
|
1420 |
+
out = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e])
|
1421 |
+
outs.append(out)
|
1422 |
+
torch_output = torch.cat(outs, dim=2)
|
1423 |
+
else:
|
1424 |
+
attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale
|
1425 |
+
# import ipdb; ipdb.set_trace()
|
1426 |
+
if attn_mask is not None:
|
1427 |
+
attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf'))
|
1428 |
+
# print(f'> torch attn: {attn.exp().sum(-1)=}')
|
1429 |
+
|
1430 |
+
attn = attn.softmax(-1)
|
1431 |
+
if do is not None:
|
1432 |
+
dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do)
|
1433 |
+
print(f'> torch_attn computed dv: {dv=}')
|
1434 |
+
torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v)
|
1435 |
+
return torch_output
|
1436 |
+
|
1437 |
+
###########################################################
|
1438 |
+
###########################################################
|
1439 |
+
|
1440 |
+
###########################################################
|
1441 |
+
#################### Unit Tests ###########################
|
1442 |
+
###########################################################
|
1443 |
+
|
1444 |
+
|
1445 |
+
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)])
|
1446 |
+
def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True,
|
1447 |
+
sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None):
|
1448 |
+
Q_LEN = Q_LEN or N_CTX
|
1449 |
+
torch.manual_seed(20)
|
1450 |
+
q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
|
1451 |
+
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
|
1452 |
+
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
|
1453 |
+
|
1454 |
+
if sm_scale is None:
|
1455 |
+
sm_scale = 1. / math.sqrt(D_HEAD)
|
1456 |
+
|
1457 |
+
# for debugging
|
1458 |
+
# print(f'>> {q.shape=}, {k.shape=}, {v.shape=}, {homo_head=}, {kernel_block_size=}, {sparse_block_size=}, {local_blocks=}, {vert_stride=}')
|
1459 |
+
sm_scale = 0.0078125
|
1460 |
+
if backward:
|
1461 |
+
q.requires_grad_(), k.requires_grad_(), v.requires_grad_()
|
1462 |
+
|
1463 |
+
# qkv = torch.empty((Z, N_CTX, 3*H*D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
|
1464 |
+
# q = qkv[..., :H*D_HEAD]
|
1465 |
+
# k = qkv[..., H*D_HEAD:2*H*D_HEAD]
|
1466 |
+
# v = qkv[..., 2*H*D_HEAD:]
|
1467 |
+
# q = q.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
|
1468 |
+
# k = k.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
|
1469 |
+
# v = v.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
|
1470 |
+
|
1471 |
+
# if Q_LEN and Q_LEN < N_CTX:
|
1472 |
+
# q = q[:, :, -Q_LEN:] # .contiguous()
|
1473 |
+
|
1474 |
+
# q = q.requires_grad_()
|
1475 |
+
# k = k.requires_grad_()
|
1476 |
+
# v = v.requires_grad_()
|
1477 |
+
|
1478 |
+
dout = torch.randn_like(q).contiguous()
|
1479 |
+
|
1480 |
+
# dout = torch.eye(N_CTX)[:, :D_HEAD][None, None].expand_as(q).type_as(q).contiguous()
|
1481 |
+
# print(dout)
|
1482 |
+
|
1483 |
+
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size,
|
1484 |
+
local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True)
|
1485 |
+
|
1486 |
+
if sparse_attention_fn is None:
|
1487 |
+
sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX,
|
1488 |
+
sparse_block_size=sparse_block_size,
|
1489 |
+
local_blocks=local_blocks,
|
1490 |
+
vert_stride=vert_stride,
|
1491 |
+
homo_head=homo_head,
|
1492 |
+
device=q.device,
|
1493 |
+
dtype=q.dtype,
|
1494 |
+
kernel_block_size=kernel_block_size)
|
1495 |
+
# reference implementation
|
1496 |
+
ref_out = torch_attention(q, k, v, mask_dense, sm_scale)
|
1497 |
+
|
1498 |
+
# lengths = torch.full((Z,), fill_value=N_CTX, device='cuda')
|
1499 |
+
# cu_seqlens = torch.zeros((Z + 1,), device='cuda', dtype=torch.int32)
|
1500 |
+
# cu_seqlens[1:] = lengths.cumsum(0)
|
1501 |
+
# # qkv = torch.randn((Z * N_CTX, 3, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1502 |
+
|
1503 |
+
# qkv_list = list(map(lambda x: x.permute(0, 2, 1, 3).contiguous().view(Z * N_CTX, 1, H, D_HEAD), [q, k, v]))
|
1504 |
+
# qkv = torch.cat(qkv_list, dim=1)
|
1505 |
+
# ref_out0 = flash_attn_func(qkv, cu_seqlens, dropout_p=0, max_s=N_CTX, softmax_scale=sm_scale, causal=True)
|
1506 |
+
# ref_out = ref_out0.view(Z, N_CTX, H, D_HEAD).permute(0, 2, 1, 3).contiguous()
|
1507 |
+
|
1508 |
+
|
1509 |
+
if backward:
|
1510 |
+
ref_out.backward(dout)
|
1511 |
+
ref_dv, v.grad = v.grad.clone(), None
|
1512 |
+
ref_dk, k.grad = k.grad.clone(), None
|
1513 |
+
ref_dq, q.grad = q.grad.clone(), None
|
1514 |
+
|
1515 |
+
tri_out = sparse_attention_fn(q, k, v, sm_scale)
|
1516 |
+
|
1517 |
+
decimal = 1 if dtype == torch.bfloat16 else 2
|
1518 |
+
assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}'
|
1519 |
+
|
1520 |
+
if backward:
|
1521 |
+
tri_out.backward(dout)
|
1522 |
+
tri_dv, v.grad = v.grad.clone(), None
|
1523 |
+
tri_dk, k.grad = k.grad.clone(), None
|
1524 |
+
tri_dq, q.grad = q.grad.clone(), None
|
1525 |
+
|
1526 |
+
if backward:
|
1527 |
+
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
|
1528 |
+
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
|
1529 |
+
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
|
1530 |
+
|
1531 |
+
print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}')
|
1532 |
+
|
1533 |
+
###########################################################
|
1534 |
+
|
1535 |
+
if __name__ == '__main__':
|
1536 |
+
|
1537 |
+
GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip()
|
1538 |
+
# print(GPU_TYPE)
|
1539 |
+
support_backward = True # 'A100' in GPU_TYPE. Wasn't supportted in consumer A1000.
|
1540 |
+
|
1541 |
+
###############
|
1542 |
+
# benchmarking
|
1543 |
+
|
1544 |
+
HAS_DENSE_TRITON_FLASH = False
|
1545 |
+
# try:
|
1546 |
+
# from triton.ops.flash_attention import attention as triton_attention
|
1547 |
+
# HAS_DENSE_TRITON_FLASH = True
|
1548 |
+
# except:
|
1549 |
+
# HAS_DENSE_TRITON_FLASH = False
|
1550 |
+
# print('> cannot import Trition flash attn')
|
1551 |
+
|
1552 |
+
try:
|
1553 |
+
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func
|
1554 |
+
HAS_FLASH = True
|
1555 |
+
except BaseException:
|
1556 |
+
HAS_FLASH = False
|
1557 |
+
print('> cannot import flash_attn')
|
1558 |
+
|
1559 |
+
|
1560 |
+
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
1561 |
+
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128 # 6.7B model, with 4k len
|
1562 |
+
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 16, 4096, 128 # 204m model
|
1563 |
+
|
1564 |
+
BLOCK_SIZE = 64
|
1565 |
+
LOCAl_BLOCKS = 8 # 4
|
1566 |
+
VERT_STRIDE = 1 # 16 # 8
|
1567 |
+
HOMO_HEAD = False
|
1568 |
+
sparse_type = 'home' if HOMO_HEAD else 'hetero'
|
1569 |
+
dtype = torch.bfloat16
|
1570 |
+
|
1571 |
+
|
1572 |
+
modes = ['fwd', 'bwd'] if support_backward else ['fwd']
|
1573 |
+
|
1574 |
+
configs = [triton.testing.Benchmark(
|
1575 |
+
x_names=['SEQ_LEN'],
|
1576 |
+
x_vals=[2**i for i in range(8, 16)],
|
1577 |
+
line_arg='provider',
|
1578 |
+
line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'],
|
1579 |
+
line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else []) + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'],
|
1580 |
+
styles=[('red', '-'), ('blue', '-'), ('green', '-')],
|
1581 |
+
ylabel='ms',
|
1582 |
+
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}',
|
1583 |
+
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode}
|
1584 |
+
) for mode in modes]
|
1585 |
+
|
1586 |
+
|
1587 |
+
@triton.testing.perf_report(configs)
|
1588 |
+
def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
|
1589 |
+
assert mode in ['fwd', 'bwd']
|
1590 |
+
warmup = 25
|
1591 |
+
rep = 100
|
1592 |
+
N_CTX = SEQ_LEN
|
1593 |
+
if provider == 'triton':
|
1594 |
+
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1595 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1596 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1597 |
+
sm_scale = 1.3
|
1598 |
+
fn = lambda: triton_attention(q, k, v, sm_scale)
|
1599 |
+
if mode == 'bwd':
|
1600 |
+
o = fn()
|
1601 |
+
do = torch.randn_like(o)
|
1602 |
+
fn = lambda: o.backward(do, retain_graph=True)
|
1603 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1604 |
+
return ms
|
1605 |
+
if provider == 'triton_sparse':
|
1606 |
+
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1607 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1608 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1609 |
+
sm_scale = 1.3
|
1610 |
+
# q_pos = torch.arange(N_CTX // BLOCK, device='cuda')[:, None]
|
1611 |
+
# k_pos = torch.arange(N_CTX // BLOCK, device='cuda')[None]
|
1612 |
+
# local_blocks = 4 # num_block per attn, block_size is tied to BLOCK
|
1613 |
+
# vert_stride =N_CTX + 1 # 4
|
1614 |
+
# mask_vert_strided = torch.arange(N_CTX // BLOCK, device='cuda') % vert_stride == vert_stride - 1
|
1615 |
+
# mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).type_as(q)
|
1616 |
+
# mask = mask_dense.to_sparse_csr()
|
1617 |
+
# mask_csr, _ = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK, local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD)
|
1618 |
+
|
1619 |
+
if sparse_attention_fn is None:
|
1620 |
+
# sparse_attention_fn = sparse_attention
|
1621 |
+
sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN,
|
1622 |
+
local_blocks=LOCAl_BLOCKS,
|
1623 |
+
vert_stride=VERT_STRIDE,
|
1624 |
+
homo_head=HOMO_HEAD,
|
1625 |
+
sparse_block_size=BLOCK_SIZE,
|
1626 |
+
kernel_block_size=BLOCK_SIZE,
|
1627 |
+
device=q.device)
|
1628 |
+
# sparse_attention_fn = sparse_attention_factory(128, 128, num_warps=8)
|
1629 |
+
|
1630 |
+
# fn = lambda: sparse_attention_fn(q, k, v, mask_csr[0], mask_csr[1], sm_scale)
|
1631 |
+
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
|
1632 |
+
if mode == 'bwd':
|
1633 |
+
o = fn()
|
1634 |
+
do = torch.randn_like(o)
|
1635 |
+
fn = lambda: o.backward(do, retain_graph=True)
|
1636 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1637 |
+
return ms
|
1638 |
+
if provider == 'flash':
|
1639 |
+
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
1640 |
+
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
1641 |
+
cu_seqlens[1:] = lengths.cumsum(0)
|
1642 |
+
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
1643 |
+
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
1644 |
+
if mode == 'bwd':
|
1645 |
+
o = fn()
|
1646 |
+
do = torch.randn_like(o)
|
1647 |
+
fn = lambda: o.backward(do, retain_graph=True)
|
1648 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1649 |
+
return ms
|
1650 |
+
|
1651 |
+
# if provider == 'torch':
|
1652 |
+
# q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1653 |
+
# k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1654 |
+
# v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1655 |
+
# sm_scale = 1.3
|
1656 |
+
# causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(q)
|
1657 |
+
# fn = lambda: torch_attention(q, k, v, causal_mask, sm_scale)
|
1658 |
+
# ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
1659 |
+
# return ms
|
1660 |
+
|
1661 |
+
|
1662 |
+
BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1 # 6.7B model, with 4k len
|
1663 |
+
|
1664 |
+
BLOCK_SIZE = 64
|
1665 |
+
LOCAl_BLOCKS = 8 # 4
|
1666 |
+
VERT_STRIDE = 16 # 8
|
1667 |
+
HOMO_HEAD = False
|
1668 |
+
sparse_type = 'home' if HOMO_HEAD else 'hetero'
|
1669 |
+
dtype = torch.bfloat16
|
1670 |
+
MAX_N_CTX = 8192
|
1671 |
+
|
1672 |
+
configs = [triton.testing.Benchmark(
|
1673 |
+
x_names=['PAST_LEN'],
|
1674 |
+
x_vals=[2**i - 1 for i in range(8, 14)],
|
1675 |
+
line_arg='provider',
|
1676 |
+
line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'],
|
1677 |
+
line_names=['Torch'] + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'],
|
1678 |
+
styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')],
|
1679 |
+
ylabel='ms',
|
1680 |
+
plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}',
|
1681 |
+
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode}
|
1682 |
+
) for mode in ['fwd']]
|
1683 |
+
@triton.testing.perf_report(configs)
|
1684 |
+
def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
|
1685 |
+
assert mode in ['fwd']
|
1686 |
+
warmup = 25
|
1687 |
+
rep = 100
|
1688 |
+
N_CTX = PAST_LEN + Q_LEN
|
1689 |
+
if provider == 'torch':
|
1690 |
+
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1691 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1692 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1693 |
+
sm_scale = 1.3
|
1694 |
+
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE,
|
1695 |
+
local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=VERT_STRIDE, return_dense=True)
|
1696 |
+
|
1697 |
+
fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048)
|
1698 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1699 |
+
return ms
|
1700 |
+
if provider == 'triton_sparse':
|
1701 |
+
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1702 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1703 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1704 |
+
sm_scale = 1.3
|
1705 |
+
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
|
1706 |
+
local_blocks=LOCAl_BLOCKS,
|
1707 |
+
vert_stride=VERT_STRIDE,
|
1708 |
+
homo_head=HOMO_HEAD,
|
1709 |
+
sparse_block_size=BLOCK_SIZE,
|
1710 |
+
kernel_block_size=BLOCK_SIZE,
|
1711 |
+
device=q.device,
|
1712 |
+
inference=True)
|
1713 |
+
|
1714 |
+
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
|
1715 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1716 |
+
return ms
|
1717 |
+
if provider == 'triton_dense':
|
1718 |
+
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1719 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1720 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1721 |
+
sm_scale = 1.3
|
1722 |
+
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
|
1723 |
+
local_blocks=1,
|
1724 |
+
vert_stride=1,
|
1725 |
+
homo_head=True,
|
1726 |
+
sparse_block_size=BLOCK_SIZE,
|
1727 |
+
kernel_block_size=BLOCK_SIZE,
|
1728 |
+
device=q.device,
|
1729 |
+
inference=True)
|
1730 |
+
|
1731 |
+
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
|
1732 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1733 |
+
return ms
|
1734 |
+
if provider == 'flash':
|
1735 |
+
assert Q_LEN == 1
|
1736 |
+
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
1737 |
+
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
1738 |
+
cu_seqlens[1:] = lengths.cumsum(0)
|
1739 |
+
cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32)
|
1740 |
+
|
1741 |
+
# (total_q, nheads, headdim),
|
1742 |
+
q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1743 |
+
k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1744 |
+
v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1745 |
+
|
1746 |
+
fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False)
|
1747 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1748 |
+
return ms
|
1749 |
+
|
1750 |
+
|
1751 |
+
test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
|
1752 |
+
# bench_flash_attention.run(save_path='.', print_data=True)
|
1753 |
+
|
1754 |
+
bench_flash_attention_inference.run(save_path='.', print_data=True)
|
1755 |
+
exit()
|
1756 |
+
# head_dim=64
|
1757 |
+
test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64,
|
1758 |
+
dtype=torch.bfloat16, homo_head=False, backward=support_backward)
|
1759 |
+
# uneven length, bf16
|
1760 |
+
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128,
|
1761 |
+
kernel_block_size=64, local_blocks=8, vert_stride=8)
|
1762 |
+
test_op(3, 2, 2047, 128, homo_head=False, backward=False)
|
1763 |
+
|
1764 |
+
# diff kernel/sparse block size
|
1765 |
+
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64)
|
1766 |
+
# inference
|
1767 |
+
# test_op(1, 4, 512 + 256, 128, Q_LEN=1, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
|
1768 |
+
|
1769 |
+
# dense flash attn
|
1770 |
+
test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False,
|
1771 |
+
backward=support_backward, local_blocks=1, vert_stride=1)
|
1772 |
+
|
1773 |
+
# fp16
|
1774 |
+
test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
|
1775 |
+
|
1776 |
+
# longer sequence
|
1777 |
+
test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward)
|
1778 |
+
test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
|
1779 |
+
|
1780 |
+
# homo head
|
1781 |
+
test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False)
|
1782 |
+
test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward)
|
1783 |
+
|
1784 |
+
# sparse_attention_fn = sparse_attention_factory(16, 128, num_warps=1, INFERENCE=True)
|
1785 |
+
# test_op(8, 1, 2047, 128, 1, backward=False, sparse_attention_fn=None)
|
1786 |
+
# test_op_inference(3, 2, 2048, 128, 2048)
|
1787 |
+
# test_op_inference(3, 2, 2047, 64, 2047)
|
1788 |
+
# test_op_inference(3, 2, 256, 64, 128)
|
1789 |
+
# test_op_inference(3, 2, 2048, 64, 1)
|
1790 |
+
|
1791 |
+
bench_flash_attention.run(save_path='.', print_data=True)
|
1792 |
+
# bench_flash_attention_inference.run(save_path='.', print_data=True)
|
1793 |
+
|
1794 |
+
# ========================
|
1795 |
+
# Some Benchmark Results #
|
1796 |
+
# ========================
|
1797 |
+
|
1798 |
+
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-fwd
|
1799 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1800 |
+
# 0 256.0 0.057184 0.069646 0.052567
|
1801 |
+
# 1 512.0 0.131688 0.187658 0.110212
|
1802 |
+
# 2 1024.0 0.391844 0.524990 0.247875
|
1803 |
+
# 3 2048.0 1.305190 1.456685 0.596506
|
1804 |
+
# 4 4096.0 4.623019 4.968653 1.600277
|
1805 |
+
# 5 8192.0 17.513062 18.332262 4.802458
|
1806 |
+
# 6 16384.0 68.453377 70.337540 16.052908
|
1807 |
+
# 7 32768.0 270.655487 276.020233 57.938946
|
1808 |
+
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-bwd (num_warp=8):
|
1809 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1810 |
+
# 0 256.0 0.190120 0.150313 0.181451
|
1811 |
+
# 1 512.0 0.406348 0.391767 0.391177
|
1812 |
+
# 2 1024.0 1.029704 1.182967 0.885741
|
1813 |
+
# 3 2048.0 2.985456 3.843399 2.040469
|
1814 |
+
# 4 4096.0 9.808897 13.073701 5.069609
|
1815 |
+
# 5 8192.0 34.995201 47.863808 13.948782
|
1816 |
+
# 6 16384.0 132.740097 182.579193 42.816513
|
1817 |
+
# 7 32768.0 542.223389 714.820618 147.053574
|
1818 |
+
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
|
1819 |
+
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
|
1820 |
+
# 0 256.0 0.050949 0.032357 0.107513
|
1821 |
+
# 1 512.0 0.073624 0.050651 0.199086
|
1822 |
+
# 2 1024.0 0.107472 0.080379 0.245445
|
1823 |
+
# 3 2048.0 0.178423 0.129448 0.338259
|
1824 |
+
# 4 4096.0 0.327647 0.223106 0.517048
|
1825 |
+
# 5 8192.0 0.588423 0.411263 0.884606
|
1826 |
+
# 6 16384.0 1.098898 0.798941 1.611809
|
1827 |
+
# 7 32768.0 2.094537 1.594726 3.044160
|
1828 |
+
|
1829 |
+
|
1830 |
+
# 6.7B
|
1831 |
+
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-fwd:
|
1832 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1833 |
+
# 0 256.0 0.069208 0.082156 0.065097
|
1834 |
+
# 1 512.0 0.138271 0.201393 0.144467
|
1835 |
+
# 2 1024.0 0.391521 0.624614 0.322382
|
1836 |
+
# 3 2048.0 1.268443 2.406325 0.784367
|
1837 |
+
# 4 4096.0 4.455703 9.139097 2.100856
|
1838 |
+
# 5 8192.0 16.764315 35.289600 6.328320
|
1839 |
+
# 6 16384.0 65.221634 138.401794 21.069057
|
1840 |
+
# 7 32768.0 257.251343 548.085754 76.111870
|
1841 |
+
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-bwd:
|
1842 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1843 |
+
# 0 256.0 0.297118 0.266469 0.255255
|
1844 |
+
# 1 512.0 0.672826 0.613685 0.552954
|
1845 |
+
# 2 1024.0 1.718434 1.705066 1.251953
|
1846 |
+
# 3 2048.0 4.936755 5.403875 2.927895
|
1847 |
+
# 4 4096.0 15.911594 18.959362 7.436288
|
1848 |
+
# 5 8192.0 55.357441 70.808578 21.140224
|
1849 |
+
# 6 16384.0 208.188416 273.617920 68.018173
|
1850 |
+
# 7 32768.0 806.037476 1081.453613 218.720261
|
1851 |
+
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
|
1852 |
+
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
|
1853 |
+
# 0 256.0 0.050151 0.032337 0.107593
|
1854 |
+
# 1 512.0 0.073409 0.051737 0.200200
|
1855 |
+
# 2 1024.0 0.107533 0.082099 0.247067
|
1856 |
+
# 3 2048.0 0.177259 0.128891 0.338510
|
1857 |
+
# 4 4096.0 0.325866 0.223621 0.524842
|
1858 |
+
# 5 8192.0 0.586926 0.408913 0.885490
|
1859 |
+
# 6 16384.0 1.100834 0.793277 1.612271
|
1860 |
+
# 7 32768.0 2.098851 1.595831 3.064544
|
1861 |
+
|
1862 |
+
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-fwd:
|
1863 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1864 |
+
# 0 256.0 0.066673 0.082037 0.065085
|
1865 |
+
# 1 512.0 0.137379 0.201880 0.143473
|
1866 |
+
# 2 1024.0 0.390675 0.624234 0.312046
|
1867 |
+
# 3 2048.0 1.267739 2.406950 0.696045
|
1868 |
+
# 4 4096.0 4.445138 9.136333 1.665788
|
1869 |
+
# 5 8192.0 16.768614 35.265533 4.380486
|
1870 |
+
# 6 16384.0 65.235970 138.393600 12.997633
|
1871 |
+
# 7 32768.0 257.317902 550.442993 42.821121
|
1872 |
+
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-bwd:
|
1873 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1874 |
+
# 0 256.0 0.296461 0.266581 0.254022
|
1875 |
+
# 1 512.0 0.671427 0.613643 0.551283
|
1876 |
+
# 2 1024.0 1.719918 1.704295 1.229982
|
1877 |
+
# 3 2048.0 4.945305 5.403364 2.721906
|
1878 |
+
# 4 4096.0 15.934293 18.960999 6.259371
|
1879 |
+
# 5 8192.0 55.406593 70.832130 15.676929
|
1880 |
+
# 6 16384.0 208.750595 275.004425 44.837891
|
1881 |
+
# 7 32768.0 808.057861 1080.647705 141.856766
|
1882 |
+
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert8-hetero:
|
1883 |
+
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
|
1884 |
+
# 0 256.0 0.050739 0.032886 0.107837
|
1885 |
+
# 1 512.0 0.073507 0.051996 0.200293
|
1886 |
+
# 2 1024.0 0.106394 0.080679 0.240610
|
1887 |
+
# 3 2048.0 0.177659 0.127660 0.287625
|
1888 |
+
# 4 4096.0 0.326326 0.226971 0.377500
|
1889 |
+
# 5 8192.0 0.586339 0.407367 0.559266
|
1890 |
+
# 6 16384.0 1.102279 0.786221 0.920976
|
1891 |
+
# 7 32768.0 2.097370 1.545090 1.644288
|
1892 |
+
|
1893 |
+
|
1894 |
+
################
|
1895 |
+
##### fp16 #####
|
1896 |
+
################
|
1897 |
+
|
1898 |
+
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
|
1899 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1900 |
+
# 0 256.0 0.032518 0.035472 0.029939
|
1901 |
+
# 1 512.0 0.054266 0.087841 0.054320
|
1902 |
+
# 2 1024.0 0.133447 0.263090 0.102045
|
1903 |
+
# 3 2048.0 0.384615 1.023293 0.201763
|
1904 |
+
# 4 4096.0 1.300890 4.023936 0.449555
|
1905 |
+
# 5 8192.0 4.774144 15.816704 1.150854
|
1906 |
+
# 6 16384.0 18.220032 62.771198 3.356001
|
1907 |
+
# 7 32768.0 71.405571 250.273788 10.976142
|
1908 |
+
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
|
1909 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1910 |
+
# 0 256.0 0.083342 0.069742 0.079496
|
1911 |
+
# 1 512.0 0.159894 0.170995 0.151705
|
1912 |
+
# 2 1024.0 0.386071 0.522407 0.331443
|
1913 |
+
# 3 2048.0 1.067715 1.737333 0.715248
|
1914 |
+
# 4 4096.0 3.382731 6.219520 1.597457
|
1915 |
+
# 5 8192.0 11.857793 23.560448 3.879035
|
1916 |
+
# 6 16384.0 44.422142 91.251709 10.626843
|
1917 |
+
# 7 32768.0 175.011841 359.473145 32.340992
|
1918 |
+
|
1919 |
+
|
1920 |
+
################
|
1921 |
+
##### bf16 #####
|
1922 |
+
################
|
1923 |
+
|
1924 |
+
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
|
1925 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1926 |
+
# 0 256.0 0.037636 0.035902 0.031512
|
1927 |
+
# 1 512.0 0.058591 0.087229 0.058125
|
1928 |
+
# 2 1024.0 0.143337 0.263919 0.108443
|
1929 |
+
# 3 2048.0 0.414458 1.025985 0.214114
|
1930 |
+
# 4 4096.0 1.390841 4.020010 0.480550
|
1931 |
+
# 5 8192.0 5.067938 15.808171 1.230874
|
1932 |
+
# 6 16384.0 19.442280 62.765057 3.597274
|
1933 |
+
# 7 32768.0 75.501572 250.443771 11.768959
|
1934 |
+
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
|
1935 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1936 |
+
# 0 256.0 0.084404 0.070663 0.082613
|
1937 |
+
# 1 512.0 0.161510 0.172882 0.157661
|
1938 |
+
# 2 1024.0 0.388954 0.526047 0.339855
|
1939 |
+
# 3 2048.0 1.075814 1.736057 0.732420
|
1940 |
+
# 4 4096.0 3.401622 6.221376 1.636039
|
1941 |
+
# 5 8192.0 11.915136 23.483391 3.968725
|
1942 |
+
# 6 16384.0 44.660225 91.302910 10.857130
|
1943 |
+
# 7 32768.0 175.038467 359.048187 32.778240
|