Upload folder using huggingface_hub (#1)
Browse files- 280a4464349485d578cd0583f7e14d683f0783e5ead54f847d48a475f96be979 (4514f143501729dad06b8894f8e96d3f1f665356)
- 25300cc843729640862f53f6b2617077f0f3d295d8b67795b949167bb850fbe6 (0fd61643772d9961673f09e163a0501900f20bcf)
- README.md +85 -0
- cl100k_base.tiktoken +0 -0
- config.json +160 -0
- configuration_phi3_small.py +250 -0
- generation_config.json +10 -0
- model.safetensors +3 -0
- modeling_phi3_small.py +1140 -0
- positional_embedding.py +288 -0
- smash_config.json +31 -0
- special_tokens_map.json +5 -0
- tokenization_phi3_small.py +313 -0
- tokenizer_config.json +20 -0
- triton_blocksparse_attention_layer.py +176 -0
- triton_flash_blocksparse_attn.py +1947 -0
README.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
thumbnail: "https://assets-global.website-files.com/646b351987a8d8ce158d1940/64ec9e96b4334c0e1ac41504_Logo%20with%20white%20text.svg"
|
3 |
+
base_model: numind/NuExtract-large
|
4 |
+
metrics:
|
5 |
+
- memory_disk
|
6 |
+
- memory_inference
|
7 |
+
- inference_latency
|
8 |
+
- inference_throughput
|
9 |
+
- inference_CO2_emissions
|
10 |
+
- inference_energy_consumption
|
11 |
+
tags:
|
12 |
+
- pruna-ai
|
13 |
+
---
|
14 |
+
<!-- header start -->
|
15 |
+
<!-- 200823 -->
|
16 |
+
<div style="width: auto; margin-left: auto; margin-right: auto">
|
17 |
+
<a href="https://www.pruna.ai/" target="_blank" rel="noopener noreferrer">
|
18 |
+
<img src="https://i.imgur.com/eDAlcgk.png" alt="PrunaAI" style="width: 100%; min-width: 400px; display: block; margin: auto;">
|
19 |
+
</a>
|
20 |
+
</div>
|
21 |
+
<!-- header end -->
|
22 |
+
|
23 |
+
[![Twitter](https://img.shields.io/twitter/follow/PrunaAI?style=social)](https://twitter.com/PrunaAI)
|
24 |
+
[![GitHub](https://img.shields.io/github/followers/PrunaAI?label=Follow%20%40PrunaAI&style=social)](https://github.com/PrunaAI)
|
25 |
+
[![LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue)](https://www.linkedin.com/company/93832878/admin/feed/posts/?feedType=following)
|
26 |
+
[![Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?style=social&logo=discord)](https://discord.gg/rskEr4BZJx)
|
27 |
+
|
28 |
+
# Simply make AI models cheaper, smaller, faster, and greener!
|
29 |
+
|
30 |
+
- Give a thumbs up if you like this model!
|
31 |
+
- Contact us and tell us which model to compress next [here](https://www.pruna.ai/contact).
|
32 |
+
- Request access to easily compress your *own* AI models [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
|
33 |
+
- Read the documentations to know more [here](https://pruna-ai-pruna.readthedocs-hosted.com/en/latest/)
|
34 |
+
- Join Pruna AI community on Discord [here](https://discord.gg/CP4VSgck) to share feedback/suggestions or get help.
|
35 |
+
|
36 |
+
## Results
|
37 |
+
|
38 |
+
![image info](./plots.png)
|
39 |
+
|
40 |
+
**Frequently Asked Questions**
|
41 |
+
- ***How does the compression work?*** The model is compressed with llm-int8.
|
42 |
+
- ***How does the model quality change?*** The quality of the model output might vary compared to the base model.
|
43 |
+
- ***How is the model efficiency evaluated?*** These results were obtained on HARDWARE_NAME with configuration described in `model/smash_config.json` and are obtained after a hardware warmup. The smashed model is directly compared to the original base model. Efficiency results may vary in other settings (e.g. other hardware, image size, batch size, ...). We recommend to directly run them in the use-case conditions to know if the smashed model can benefit you.
|
44 |
+
- ***What is the model format?*** We use safetensors.
|
45 |
+
- ***What calibration data has been used?*** If needed by the compression method, we used WikiText as the calibration data.
|
46 |
+
- ***What is the naming convention for Pruna Huggingface models?*** We take the original model name and append "turbo", "tiny", or "green" if the smashed model has a measured inference speed, inference memory, or inference energy consumption which is less than 90% of the original base model.
|
47 |
+
- ***How to compress my own models?*** You can request premium access to more compression methods and tech support for your specific use-cases [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
|
48 |
+
- ***What are "first" metrics?*** Results mentioning "first" are obtained after the first run of the model. The first run might take more memory or be slower than the subsequent runs due cuda overheads.
|
49 |
+
- ***What are "Sync" and "Async" metrics?*** "Sync" metrics are obtained by syncing all GPU processes and stop measurement when all of them are executed. "Async" metrics are obtained without syncing all GPU processes and stop when the model output can be used by the CPU. We provide both metrics since both could be relevant depending on the use-case. We recommend to test the efficiency gains directly in your use-cases.
|
50 |
+
|
51 |
+
## Setup
|
52 |
+
|
53 |
+
You can run the smashed model with these steps:
|
54 |
+
|
55 |
+
0. Check requirements from the original repo numind/NuExtract-large installed. In particular, check python, cuda, and transformers versions.
|
56 |
+
1. Make sure that you have installed quantization related packages.
|
57 |
+
```bash
|
58 |
+
pip install transformers accelerate bitsandbytes>0.37.0
|
59 |
+
```
|
60 |
+
2. Load & run the model.
|
61 |
+
```python
|
62 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
63 |
+
|
64 |
+
|
65 |
+
model = AutoModelForCausalLM.from_pretrained("PrunaAI/numind-NuExtract-large-bnb-4bit-smashed", trust_remote_code=True, device_map='auto')
|
66 |
+
tokenizer = AutoTokenizer.from_pretrained("numind/NuExtract-large")
|
67 |
+
|
68 |
+
input_ids = tokenizer("What is the color of prunes?,", return_tensors='pt').to(model.device)["input_ids"]
|
69 |
+
|
70 |
+
outputs = model.generate(input_ids, max_new_tokens=216)
|
71 |
+
tokenizer.decode(outputs[0])
|
72 |
+
```
|
73 |
+
|
74 |
+
## Configurations
|
75 |
+
|
76 |
+
The configuration info are in `smash_config.json`.
|
77 |
+
|
78 |
+
## Credits & License
|
79 |
+
|
80 |
+
The license of the smashed model follows the license of the original model. Please check the license of the original model numind/NuExtract-large before using this model which provided the base model. The license of the `pruna-engine` is [here](https://pypi.org/project/pruna-engine/) on Pypi.
|
81 |
+
|
82 |
+
## Want to compress other models?
|
83 |
+
|
84 |
+
- Contact us and tell us which model to compress next [here](https://www.pruna.ai/contact).
|
85 |
+
- Request access to easily compress your own AI models [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
|
cl100k_base.tiktoken
ADDED
The diff for this file is too large to render.
See raw diff
|
|
config.json
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/ceph/hdd/staff/charpent/.cache/modelsjayrsbxckim1dlli",
|
3 |
+
"architectures": [
|
4 |
+
"Phi3SmallForCausalLM"
|
5 |
+
],
|
6 |
+
"attention_bias": false,
|
7 |
+
"attention_dropout_prob": 0.0,
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_phi3_small.Phi3SmallConfig",
|
10 |
+
"AutoModelForCausalLM": "modeling_phi3_small.Phi3SmallForCausalLM",
|
11 |
+
"AutoModelForSequenceClassification": "numind/NuExtract-large--modeling_phi3_small.Phi3SmallForSequenceClassification",
|
12 |
+
"AutoTokenizer": "numind/NuExtract-large--tokenization_phi3_small.Phi3SmallTokenizer"
|
13 |
+
},
|
14 |
+
"blocksparse_block_size": 64,
|
15 |
+
"blocksparse_homo_head_pattern": false,
|
16 |
+
"blocksparse_num_local_blocks": 16,
|
17 |
+
"blocksparse_triton_kernel_block_size": 64,
|
18 |
+
"blocksparse_vert_stride": 8,
|
19 |
+
"bos_token_id": 100257,
|
20 |
+
"dense_attention_every_n_layers": 2,
|
21 |
+
"dummy_token_indices": [
|
22 |
+
100256,
|
23 |
+
100258,
|
24 |
+
100259,
|
25 |
+
100260,
|
26 |
+
100264,
|
27 |
+
100265,
|
28 |
+
100267,
|
29 |
+
100268,
|
30 |
+
100269,
|
31 |
+
100270,
|
32 |
+
100271,
|
33 |
+
100272,
|
34 |
+
100273,
|
35 |
+
100274,
|
36 |
+
100275,
|
37 |
+
100276,
|
38 |
+
100277,
|
39 |
+
100278,
|
40 |
+
100279,
|
41 |
+
100280,
|
42 |
+
100281,
|
43 |
+
100282,
|
44 |
+
100283,
|
45 |
+
100284,
|
46 |
+
100285,
|
47 |
+
100286,
|
48 |
+
100287,
|
49 |
+
100288,
|
50 |
+
100289,
|
51 |
+
100290,
|
52 |
+
100291,
|
53 |
+
100292,
|
54 |
+
100293,
|
55 |
+
100294,
|
56 |
+
100295,
|
57 |
+
100296,
|
58 |
+
100297,
|
59 |
+
100298,
|
60 |
+
100299,
|
61 |
+
100300,
|
62 |
+
100301,
|
63 |
+
100302,
|
64 |
+
100303,
|
65 |
+
100304,
|
66 |
+
100305,
|
67 |
+
100306,
|
68 |
+
100307,
|
69 |
+
100308,
|
70 |
+
100309,
|
71 |
+
100310,
|
72 |
+
100311,
|
73 |
+
100312,
|
74 |
+
100313,
|
75 |
+
100314,
|
76 |
+
100315,
|
77 |
+
100316,
|
78 |
+
100317,
|
79 |
+
100318,
|
80 |
+
100319,
|
81 |
+
100320,
|
82 |
+
100321,
|
83 |
+
100322,
|
84 |
+
100323,
|
85 |
+
100324,
|
86 |
+
100325,
|
87 |
+
100326,
|
88 |
+
100327,
|
89 |
+
100328,
|
90 |
+
100329,
|
91 |
+
100330,
|
92 |
+
100331,
|
93 |
+
100332,
|
94 |
+
100333,
|
95 |
+
100334,
|
96 |
+
100335,
|
97 |
+
100336,
|
98 |
+
100337,
|
99 |
+
100338,
|
100 |
+
100339,
|
101 |
+
100340,
|
102 |
+
100341,
|
103 |
+
100342,
|
104 |
+
100343,
|
105 |
+
100344,
|
106 |
+
100345,
|
107 |
+
100346,
|
108 |
+
100347,
|
109 |
+
100348,
|
110 |
+
100349,
|
111 |
+
100350,
|
112 |
+
100351
|
113 |
+
],
|
114 |
+
"embedding_dropout_prob": 0.1,
|
115 |
+
"eos_token_id": 100257,
|
116 |
+
"ff_dim_multiplier": null,
|
117 |
+
"ff_intermediate_size": 14336,
|
118 |
+
"ffn_dropout_prob": 0.1,
|
119 |
+
"gegelu_limit": 20.0,
|
120 |
+
"gegelu_pad_to_256": true,
|
121 |
+
"hidden_act": "gegelu",
|
122 |
+
"hidden_size": 4096,
|
123 |
+
"initializer_range": 0.02,
|
124 |
+
"layer_norm_epsilon": 1e-05,
|
125 |
+
"max_position_embeddings": 8192,
|
126 |
+
"model_type": "phi3small",
|
127 |
+
"mup_attn_multiplier": 1.0,
|
128 |
+
"mup_embedding_multiplier": 10.0,
|
129 |
+
"mup_use_scaling": true,
|
130 |
+
"mup_width_multiplier": 8.0,
|
131 |
+
"num_attention_heads": 32,
|
132 |
+
"num_hidden_layers": 32,
|
133 |
+
"num_key_value_heads": 8,
|
134 |
+
"pad_sequence_to_multiple_of_64": true,
|
135 |
+
"quantization_config": {
|
136 |
+
"_load_in_4bit": true,
|
137 |
+
"_load_in_8bit": false,
|
138 |
+
"bnb_4bit_compute_dtype": "bfloat16",
|
139 |
+
"bnb_4bit_quant_storage": "uint8",
|
140 |
+
"bnb_4bit_quant_type": "fp4",
|
141 |
+
"bnb_4bit_use_double_quant": false,
|
142 |
+
"llm_int8_enable_fp32_cpu_offload": false,
|
143 |
+
"llm_int8_has_fp16_weight": false,
|
144 |
+
"llm_int8_skip_modules": [
|
145 |
+
"lm_head"
|
146 |
+
],
|
147 |
+
"llm_int8_threshold": 6.0,
|
148 |
+
"load_in_4bit": true,
|
149 |
+
"load_in_8bit": false,
|
150 |
+
"quant_method": "bitsandbytes"
|
151 |
+
},
|
152 |
+
"reorder_and_upcast_attn": false,
|
153 |
+
"rope_embedding_base": 1000000,
|
154 |
+
"rope_position_scale": 1.0,
|
155 |
+
"rope_scaling": null,
|
156 |
+
"torch_dtype": "float16",
|
157 |
+
"transformers_version": "4.42.4",
|
158 |
+
"use_cache": true,
|
159 |
+
"vocab_size": 100352
|
160 |
+
}
|
configuration_phi3_small.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
from typing import Any, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
from functools import cached_property
|
22 |
+
|
23 |
+
""" Phi3Small model configuration """
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
def next_mult(x, y):
|
28 |
+
return (x + y - 1) // y * y
|
29 |
+
|
30 |
+
class Phi3SmallConfig(PretrainedConfig):
|
31 |
+
"""
|
32 |
+
This is the configuration class to store the configuration of a `Phi3Small` model. It is used to
|
33 |
+
instantiate a Phi-3-small model according to the specified arguments, defining the model architecture.
|
34 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the Phi-3-small
|
35 |
+
[phi3](https://arxiv.org/pdf/2404.14219) architecture.
|
36 |
+
|
37 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
38 |
+
documentation from [`PretrainedConfig`] for more information.
|
39 |
+
|
40 |
+
|
41 |
+
Args:
|
42 |
+
vocab_size (`int`, *optional*, defaults to 100352):
|
43 |
+
Vocabulary size of the Phi3Small model. Defines the number of different tokens that can be represented by the
|
44 |
+
`inputs_ids` passed when calling `Phi3Small`.
|
45 |
+
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
46 |
+
The maximum sequence length that this model might safely be used with.
|
47 |
+
rope_embedding_base (`float`, *optional*, defaults to 10^6):
|
48 |
+
The base value for the RoPE (Relative Position Encoding) embedding.
|
49 |
+
rope_position_scale (`float`, *optional*, defaults to 1.0):
|
50 |
+
The scale factor for the RoPE position encoding.
|
51 |
+
rope_scaling (`Optional[Dict[str, Union[float, List[float], int]]]`, *optional*, defaults to None):
|
52 |
+
The scaling configuration used for LongRoPE.
|
53 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
54 |
+
The size of the hidden layers in the model.
|
55 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
56 |
+
The number of layers in the model.
|
57 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
58 |
+
The number of query heads in the model.
|
59 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
60 |
+
The number of key-value heads in the model.
|
61 |
+
hidden_act (`str`, *optional*, defaults to "gegelu"):
|
62 |
+
The activation function used in the model.
|
63 |
+
gegelu_limit (`float`, *optional*, defaults to 20.0):
|
64 |
+
The limit value for the GELU activation function (for numerical stability).
|
65 |
+
gegelu_pad_to_256 (`bool`, *optional*, defaults to True):
|
66 |
+
Whether to pad the intermediate size to a multiple of 256 (for faster matmul ops).
|
67 |
+
ff_dim_multiplier (`Optional[int]`, *optional*, defaults to None):
|
68 |
+
The dimension multiplier for the feed-forward layers.
|
69 |
+
ff_intermediate_size (`Optional[int]`, *optional*, defaults to 14336):
|
70 |
+
The intermediate size for the feed-forward layers.
|
71 |
+
One of `ff_dim_multiplier` or `ff_intermediate_size` must be specified.
|
72 |
+
blocksparse_homo_head_pattern (`bool`, *optional*, defaults to False):
|
73 |
+
Whether to use a homogeneous head pattern for block-sparse attention.
|
74 |
+
blocksparse_block_size (`int`, *optional*, defaults to 64):
|
75 |
+
The block size for block-sparse attention.
|
76 |
+
blocksparse_num_local_blocks (`int`, *optional*, defaults to 16):
|
77 |
+
The number of local blocks for block-sparse attention.
|
78 |
+
The local window used in blocksparse equals `blocksparse_num_local_blocks * blocksparse_block_size`
|
79 |
+
blocksparse_vert_stride (`int`, *optional*, defaults to 8):
|
80 |
+
The vertical stride for block-sparse attention.
|
81 |
+
blocksparse_triton_kernel_block_size (`int`, *optional*, defaults to 64):
|
82 |
+
The kernel block size for block-sparse attention.
|
83 |
+
dense_attention_every_n_layers (`Optional[int]`, *optional*, defaults to 2):
|
84 |
+
The frequency of all dense attention layers in the model
|
85 |
+
embedding_dropout_prob (`float`, *optional*, defaults to 0.1):
|
86 |
+
The dropout probability for the embedding layer.
|
87 |
+
attention_dropout_prob (`float`, *optional*, defaults to 0.0):
|
88 |
+
The dropout probability for the attention layers.
|
89 |
+
ffn_dropout_prob (`float`, *optional*, defaults to 0.1):
|
90 |
+
The dropout probability for the feed-forward layers.
|
91 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
92 |
+
The epsilon value for layer normalization.
|
93 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
94 |
+
The range for weight initialization.
|
95 |
+
mup_use_scaling (`bool`, *optional*, defaults to True):
|
96 |
+
Whether to use scaling for MuP parameters (see: https://arxiv.org/abs/2203.03466).
|
97 |
+
mup_width_multiplier (`bool`, *optional*, defaults to 8.0):
|
98 |
+
The width multiplier for MuP.
|
99 |
+
mup_embedding_multiplier (`bool`, *optional*, defaults to 10.0):
|
100 |
+
The embedding multiplier for MuP.
|
101 |
+
mup_attn_multiplier (`bool`, *optional*, defaults to 1.0):
|
102 |
+
The attention multiplier for MuP.
|
103 |
+
use_cache (`bool`, *optional*, defaults to True):
|
104 |
+
Whether to use cache for the model.
|
105 |
+
bos_token_id (`int`, *optional*, defaults to 100257):
|
106 |
+
The token ID for the beginning of sentence.
|
107 |
+
eos_token_id (`int`, *optional*, defaults to 100257):
|
108 |
+
The token ID for the end of sentence.
|
109 |
+
reorder_and_upcast_attn (`bool`, *optional*, defaults to False):
|
110 |
+
Whether to reorder and upcast attention.
|
111 |
+
pad_sequence_to_multiple_of_64 (`bool`, *optional*, defaults to True):
|
112 |
+
Whether to pad the sequence length to a multiple of 64.
|
113 |
+
**kwargs:
|
114 |
+
Additional keyword arguments.
|
115 |
+
|
116 |
+
Example:
|
117 |
+
|
118 |
+
```python
|
119 |
+
>>> from transformers import Phi3SmallConfig, Phi3SmallModel
|
120 |
+
|
121 |
+
>>> # Initializing a Phi3Small configuration
|
122 |
+
>>> configuration = Phi3SmallConfig()
|
123 |
+
|
124 |
+
>>> # Initializing a model (with random weights) from the configuration
|
125 |
+
>>> model = Phi3SmallModel(configuration)
|
126 |
+
|
127 |
+
>>> # Accessing the model configuration
|
128 |
+
>>> configuration = model.config
|
129 |
+
```
|
130 |
+
"""
|
131 |
+
|
132 |
+
model_type = "phi3small"
|
133 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
134 |
+
|
135 |
+
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
# General information about the model
|
139 |
+
vocab_size: int =100352,
|
140 |
+
max_position_embeddings: int = 8192,
|
141 |
+
# RoPE Related Parameters
|
142 |
+
rope_embedding_base: float = 10**6,
|
143 |
+
rope_position_scale: float = 1.0,
|
144 |
+
rope_scaling: Optional[Dict[str, Union[float, List[float], int]]] = None,
|
145 |
+
# General Model Parameters
|
146 |
+
hidden_size: int = 4096,
|
147 |
+
num_hidden_layers: int = 32,
|
148 |
+
# KV Shared Attention Configurations
|
149 |
+
num_attention_heads: int = 32,
|
150 |
+
num_key_value_heads: int = 8,
|
151 |
+
# GEGELU Related Parameters
|
152 |
+
hidden_act: str = "gegelu",
|
153 |
+
gegelu_limit: float = 20.0,
|
154 |
+
gegelu_pad_to_256: bool = True,
|
155 |
+
ff_dim_multiplier: Optional[int] = None,
|
156 |
+
ff_intermediate_size: Optional[int] = 14336,
|
157 |
+
# Block Sparse Attention Parameters
|
158 |
+
blocksparse_homo_head_pattern: bool = False,
|
159 |
+
blocksparse_block_size: int = 64,
|
160 |
+
blocksparse_num_local_blocks: int = 16,
|
161 |
+
blocksparse_vert_stride: int = 8,
|
162 |
+
blocksparse_triton_kernel_block_size: int = 64,
|
163 |
+
# Frequency of block-sparsity
|
164 |
+
dense_attention_every_n_layers: Optional[int] = 2,
|
165 |
+
# Reegularization parameters
|
166 |
+
embedding_dropout_prob: float =0.1,
|
167 |
+
attention_dropout_prob: float = 0.0,
|
168 |
+
ffn_dropout_prob: float = 0.1,
|
169 |
+
layer_norm_epsilon=1e-5,
|
170 |
+
initializer_range=0.02,
|
171 |
+
# MuP parameters
|
172 |
+
mup_use_scaling: bool = True,
|
173 |
+
mup_width_multiplier: bool = 8.0,
|
174 |
+
mup_embedding_multiplier: bool = 10.0,
|
175 |
+
mup_attn_multiplier: bool =1.0,
|
176 |
+
use_cache=True,
|
177 |
+
# The model does not have a bos token id
|
178 |
+
# However, in order for some of the downstream libraries to not break
|
179 |
+
# we set this to be the same as the eos_token_id
|
180 |
+
bos_token_id: int = 100257,
|
181 |
+
eos_token_id: int = 100257,
|
182 |
+
reorder_and_upcast_attn=False,
|
183 |
+
# Configuration to pad sequence length to a multiple of 64
|
184 |
+
pad_sequence_to_multiple_of_64: bool = True,
|
185 |
+
**kwargs,
|
186 |
+
):
|
187 |
+
self.vocab_size = vocab_size
|
188 |
+
self.max_position_embeddings = max_position_embeddings
|
189 |
+
self.rope_embedding_base = rope_embedding_base
|
190 |
+
self.rope_position_scale = rope_position_scale
|
191 |
+
self.rope_scaling = rope_scaling
|
192 |
+
self.hidden_size = hidden_size
|
193 |
+
# QK Shared Attention
|
194 |
+
self.num_hidden_layers = num_hidden_layers
|
195 |
+
self.num_attention_heads = num_attention_heads
|
196 |
+
self.num_key_value_heads = num_key_value_heads
|
197 |
+
# Block Sparse Attention Pattern
|
198 |
+
self.blocksparse_homo_head_pattern = blocksparse_homo_head_pattern
|
199 |
+
self.blocksparse_block_size = blocksparse_block_size
|
200 |
+
self.blocksparse_num_local_blocks = blocksparse_num_local_blocks
|
201 |
+
self.blocksparse_vert_stride = blocksparse_vert_stride
|
202 |
+
self.blocksparse_triton_kernel_block_size = blocksparse_triton_kernel_block_size
|
203 |
+
# Frequency of block sparsity
|
204 |
+
self.dense_attention_every_n_layers = dense_attention_every_n_layers
|
205 |
+
# Activation function
|
206 |
+
self.hidden_act = hidden_act
|
207 |
+
self.gegelu_limit = gegelu_limit
|
208 |
+
self.gegelu_pad_to_256 = gegelu_pad_to_256
|
209 |
+
self.ff_dim_multiplier = ff_dim_multiplier
|
210 |
+
self.ff_intermediate_size = ff_intermediate_size
|
211 |
+
if self.ff_dim_multiplier is None and self.ff_intermediate_size is None:
|
212 |
+
raise ValueError(f"Cannot have both {self.ff_dim_multiplier} and {self.ff_intermediate_size} as None")
|
213 |
+
if self.ff_dim_multiplier is not None and self.ff_intermediate_size is not None:
|
214 |
+
raise ValueError(f"Cannot specify both {self.ff_dim_multiplier} and {self.ff_intermediate_size}.")
|
215 |
+
# General regularization
|
216 |
+
self.embedding_dropout_prob = embedding_dropout_prob
|
217 |
+
self.attention_dropout_prob = attention_dropout_prob
|
218 |
+
self.ffn_dropout_prob = ffn_dropout_prob
|
219 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
220 |
+
self.initializer_range = initializer_range
|
221 |
+
# MuP parameters
|
222 |
+
self.mup_use_scaling = mup_use_scaling
|
223 |
+
self.mup_width_multiplier = mup_width_multiplier
|
224 |
+
self.mup_embedding_multiplier = mup_embedding_multiplier
|
225 |
+
self.mup_attn_multiplier = mup_attn_multiplier
|
226 |
+
self.use_cache = use_cache
|
227 |
+
|
228 |
+
self.reorder_and_upcast_attn = reorder_and_upcast_attn
|
229 |
+
self.pad_sequence_to_multiple_of_64 = pad_sequence_to_multiple_of_64
|
230 |
+
|
231 |
+
self.bos_token_id = bos_token_id
|
232 |
+
self.eos_token_id = eos_token_id
|
233 |
+
|
234 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
235 |
+
|
236 |
+
@cached_property
|
237 |
+
def dummy_token_indices(self) -> List[int]:
|
238 |
+
# Importing here to avoid circular imports
|
239 |
+
from .tokenization_phi3_small import Phi3SmallTokenizer
|
240 |
+
tokenizer = Phi3SmallTokenizer()
|
241 |
+
return tokenizer.dummy_token_indices
|
242 |
+
|
243 |
+
@property
|
244 |
+
def intermediate_size(self) -> int:
|
245 |
+
if self.ff_intermediate_size is not None:
|
246 |
+
return self.ff_intermediate_size
|
247 |
+
intermediate_size = (self.ff_dim_multiplier) * (self.hidden_size // 3) * 2
|
248 |
+
if self.gegelu_pad_to_256:
|
249 |
+
intermediate_size = next_mult(intermediate_size, 256)
|
250 |
+
return intermediate_size
|
generation_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 100257,
|
4 |
+
"eos_token_id": [
|
5 |
+
100257,
|
6 |
+
100266
|
7 |
+
],
|
8 |
+
"max_new_tokens": 2000,
|
9 |
+
"transformers_version": "4.42.4"
|
10 |
+
}
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3a8ed2260866cb4627489b8904992bda101277c420a15fc226ace769f418b4fd
|
3 |
+
size 4751891888
|
modeling_phi3_small.py
ADDED
@@ -0,0 +1,1140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any, Dict, Optional, List, Tuple, Union
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
from transformers.modeling_outputs import SequenceClassifierOutputWithPast, CausalLMOutputWithPast, BaseModelOutputWithPast
|
11 |
+
from transformers.modeling_utils import PreTrainedModel
|
12 |
+
from transformers.utils import logging
|
13 |
+
|
14 |
+
from transformers.cache_utils import Cache, DynamicCache
|
15 |
+
|
16 |
+
from .triton_flash_blocksparse_attn import BlockSparseParams
|
17 |
+
from .triton_blocksparse_attention_layer import BlockSparseAttentionLayer
|
18 |
+
from .positional_embedding import RotaryEmbedding
|
19 |
+
|
20 |
+
from .configuration_phi3_small import Phi3SmallConfig
|
21 |
+
|
22 |
+
# Flash Attention Related Imports
|
23 |
+
is_flash_attention_available = False
|
24 |
+
try:
|
25 |
+
import flash_attn
|
26 |
+
if int(flash_attn.__version__.split('.')[0]) < 2:
|
27 |
+
from flash_attn.flash_attn_interface import (
|
28 |
+
flash_attn_func,
|
29 |
+
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
30 |
+
)
|
31 |
+
|
32 |
+
# rename `max_seqlen`
|
33 |
+
def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, **kwargs):
|
34 |
+
return flash_attn_func(qkv, cu_seqlens, dropout_p=dropout_p, max_s=max_seqlen, **kwargs)
|
35 |
+
|
36 |
+
else:
|
37 |
+
from flash_attn.flash_attn_interface import (
|
38 |
+
flash_attn_varlen_kvpacked_func,
|
39 |
+
)
|
40 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
41 |
+
is_flash_attention_available = True
|
42 |
+
except ImportError:
|
43 |
+
pass
|
44 |
+
|
45 |
+
logger = logging.get_logger(__name__)
|
46 |
+
|
47 |
+
LegacyCache = Tuple[Tuple[torch.FloatTensor]]
|
48 |
+
|
49 |
+
# Taken from https://github.com/allenai/allennlp/blob/main/allennlp/nn/util.py
|
50 |
+
def info_value_of_dtype(dtype: torch.dtype):
|
51 |
+
"""
|
52 |
+
Returns the `finfo` or `iinfo` object of a given PyTorch data type. Does not allow torch.bool.
|
53 |
+
"""
|
54 |
+
if dtype == torch.bool:
|
55 |
+
raise TypeError("Does not support torch.bool")
|
56 |
+
elif dtype.is_floating_point:
|
57 |
+
return torch.finfo(dtype)
|
58 |
+
else:
|
59 |
+
return torch.iinfo(dtype)
|
60 |
+
|
61 |
+
|
62 |
+
# Taken from https://github.com/allenai/allennlp/blob/main/allennlp/nn/util.py
|
63 |
+
def min_value_of_dtype(dtype: torch.dtype):
|
64 |
+
"""
|
65 |
+
Returns the minimum value of a given PyTorch data type. Does not allow torch.bool.
|
66 |
+
"""
|
67 |
+
return info_value_of_dtype(dtype).min
|
68 |
+
|
69 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
70 |
+
def _get_unpad_data(attention_mask):
|
71 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
72 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
73 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
74 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
75 |
+
return (
|
76 |
+
indices,
|
77 |
+
cu_seqlens,
|
78 |
+
max_seqlen_in_batch,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
@torch.jit.script
|
83 |
+
def quick_gelu(x):
|
84 |
+
return x * torch.sigmoid(1.702 * x)
|
85 |
+
|
86 |
+
|
87 |
+
@torch.jit.script
|
88 |
+
def gegelu(input, limit: Optional[float] = None):
|
89 |
+
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
|
90 |
+
if limit is not None:
|
91 |
+
a_gelu = torch.where(
|
92 |
+
torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
|
93 |
+
)
|
94 |
+
a_linear = torch.where(
|
95 |
+
torch.isinf(a_linear), a_linear, a_linear.clamp(min=-limit, max=limit)
|
96 |
+
)
|
97 |
+
out_gelu = quick_gelu(a_gelu)
|
98 |
+
return out_gelu * (a_linear + 1)
|
99 |
+
|
100 |
+
def collapse_first_n_dims(x: torch.Tensor, n: int) -> torch.Tensor:
|
101 |
+
"""
|
102 |
+
Collapse the first `n` dimensions of a tensor into a single dimension.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
x (torch.Tensor): The input tensor.
|
106 |
+
n (int): The number of dimensions to collapse.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
torch.Tensor: The output tensor.
|
110 |
+
"""
|
111 |
+
return x.view(-1, *x.shape[n:])
|
112 |
+
|
113 |
+
def pad_tensor_to_next_mult_of(
|
114 |
+
tensor: torch.Tensor,
|
115 |
+
dim: int,
|
116 |
+
n: int,
|
117 |
+
) -> Tuple[torch.Tensor, int]:
|
118 |
+
"""
|
119 |
+
Pads a tensor along a specified dimension to the next multiple of a given number.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
tensor (torch.Tensor): The input tensor.
|
123 |
+
dim (int): The dimension along which to pad the tensor.
|
124 |
+
n (int): The number to pad the tensor to the next multiple of.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
Tuple[torch.Tensor, int]: A tuple containing the padded tensor and the amount of padding added.
|
128 |
+
"""
|
129 |
+
residual = tensor.size(dim) % n
|
130 |
+
if residual == 0:
|
131 |
+
return tensor, 0
|
132 |
+
padding = n - residual
|
133 |
+
padding_tensor = torch.zeros((*tensor.size()[:dim], padding, *tensor.size()[dim + 1:]), device=tensor.device, dtype=tensor.dtype)
|
134 |
+
return torch.cat([tensor, padding_tensor], dim=dim), padding
|
135 |
+
|
136 |
+
def strip_padding_from_tensor(
|
137 |
+
tensor: torch.Tensor,
|
138 |
+
dim: int,
|
139 |
+
residual: int,
|
140 |
+
) -> torch.Tensor:
|
141 |
+
"""
|
142 |
+
Removes padding from a tensor along a specified dimension.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
tensor (torch.Tensor): The input tensor.
|
146 |
+
dim (int): The dimension along which to remove padding.
|
147 |
+
residual (int): The amount of padding to remove.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
torch.Tensor: The tensor with padding removed along the specified dimension.
|
151 |
+
"""
|
152 |
+
return torch.narrow(tensor, dim, 0, tensor.size(dim) - residual)
|
153 |
+
|
154 |
+
class Phi3SmallMLP(nn.Module):
|
155 |
+
def __init__(self, config: Phi3SmallConfig):
|
156 |
+
super().__init__()
|
157 |
+
self.config = config
|
158 |
+
assert self.config.hidden_act == "gegelu", "Only `gegelu` is supported for the Phi-3-small model .."
|
159 |
+
self.hidden_size = config.hidden_size
|
160 |
+
self.gegelu_limit = config.gegelu_limit
|
161 |
+
self.intermediate_size = config.intermediate_size
|
162 |
+
|
163 |
+
self.up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size)
|
164 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
|
165 |
+
self.dropout = nn.Dropout(config.ffn_dropout_prob)
|
166 |
+
|
167 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
168 |
+
return self.dropout(
|
169 |
+
self.down_proj(
|
170 |
+
gegelu(self.up_proj(x), limit=self.gegelu_limit)
|
171 |
+
)
|
172 |
+
)
|
173 |
+
|
174 |
+
|
175 |
+
class Phi3SmallSelfAttention(nn.Module):
|
176 |
+
def __init__(self, config: Phi3SmallConfig, layer_idx: Optional[int] = None) -> None:
|
177 |
+
super().__init__()
|
178 |
+
self.config = config
|
179 |
+
self.layer_idx = layer_idx
|
180 |
+
if layer_idx is None:
|
181 |
+
logger.warning_once(
|
182 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
183 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
184 |
+
"when creating this class."
|
185 |
+
)
|
186 |
+
|
187 |
+
self.hidden_size = config.hidden_size
|
188 |
+
# Number of Query Heads
|
189 |
+
self.num_heads = config.num_attention_heads
|
190 |
+
self.head_dim = self.hidden_size // self.num_heads
|
191 |
+
# Number of Key Value Heads
|
192 |
+
self.num_key_value_heads = config.num_key_value_heads
|
193 |
+
self.num_q_per_kv = self.num_heads // self.num_key_value_heads
|
194 |
+
self.max_position_embeddings = config.max_position_embeddings
|
195 |
+
self.rope_embedding_base = config.rope_embedding_base
|
196 |
+
self.rope_position_scale = config.rope_position_scale
|
197 |
+
self.is_causal = True
|
198 |
+
|
199 |
+
self.attention_dropout_rate = config.attention_dropout_prob
|
200 |
+
|
201 |
+
norm_factor = None
|
202 |
+
if config.mup_use_scaling:
|
203 |
+
norm_factor = self.head_dim / config.mup_attn_multiplier
|
204 |
+
else:
|
205 |
+
norm_factor = math.sqrt(self.head_dim)
|
206 |
+
self.softmax_scale = 1.0 / norm_factor
|
207 |
+
|
208 |
+
self.query_key_value = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim)
|
209 |
+
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
|
210 |
+
|
211 |
+
self.blocksparse_params = None
|
212 |
+
# layer_idx is 0 indexed because that's what the KV Cache expects.
|
213 |
+
if self.config.dense_attention_every_n_layers and ((self.layer_idx + 1) % self.config.dense_attention_every_n_layers == 0):
|
214 |
+
logger.info(
|
215 |
+
f"Layer {layer_idx + 1} is using dense attention since it is divisible by "
|
216 |
+
f"{self.config.dense_attention_every_n_layers}"
|
217 |
+
)
|
218 |
+
assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
|
219 |
+
else:
|
220 |
+
# BlockSparse related Parameters
|
221 |
+
self.blocksparse_params = BlockSparseParams.from_config(config)
|
222 |
+
|
223 |
+
if self.blocksparse:
|
224 |
+
active_head_range = None
|
225 |
+
"""
|
226 |
+
... note(bapatra)::
|
227 |
+
|
228 |
+
In case of tensor parallelism and while using the heterogeneous head patterns,
|
229 |
+
the active head range needs to be modified based on the tensor parallel rank
|
230 |
+
and the tensor parallel world size.
|
231 |
+
|
232 |
+
This is because in the case of heterogeneous head patterns, the kernel needs to know
|
233 |
+
which head is on which device, so that it can pick the corresponding blocksparse head
|
234 |
+
pattern correctly.
|
235 |
+
|
236 |
+
Example:
|
237 |
+
```python
|
238 |
+
|
239 |
+
if not self.blocksparse_params.homo_head_pattern:
|
240 |
+
tp_rank = torch.distributed.get_rank() % tp_world_size
|
241 |
+
num_heads_per_partition = num_heads // tp_world_size
|
242 |
+
active_head_range = (tp_rank * num_heads_per_partition, (tp_rank + 1) * num_heads_per_partition)
|
243 |
+
|
244 |
+
```
|
245 |
+
|
246 |
+
"""
|
247 |
+
|
248 |
+
self._blocksparse_layer = BlockSparseAttentionLayer(
|
249 |
+
n_heads=self.num_heads,
|
250 |
+
max_seq_len=self.max_position_embeddings,
|
251 |
+
sparse_block_size=self.blocksparse_params.block_size,
|
252 |
+
local_blocks=self.blocksparse_params.num_local_blocks,
|
253 |
+
vert_stride=self.blocksparse_params.vert_stride,
|
254 |
+
kernel_block_size=self.blocksparse_params.kernel_block_size,
|
255 |
+
homo_head=self.blocksparse_params.homo_head_pattern,
|
256 |
+
active_head_range=active_head_range,
|
257 |
+
)
|
258 |
+
self.rotary_emb = RotaryEmbedding.from_config(config)
|
259 |
+
|
260 |
+
|
261 |
+
@property
|
262 |
+
def blocksparse(self):
|
263 |
+
return self.blocksparse_params is not None
|
264 |
+
|
265 |
+
def _split_heads(self, mixed_x_layer: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
266 |
+
bs, sq, _ = mixed_x_layer.size()
|
267 |
+
r"""
|
268 |
+
The main idea is that we group tensors as
|
269 |
+
[bs, sq, (q00, q01, ... q0m, k0, v0), (q10, q11, ... q1m, k1, v1), ... (qn0, qn1, ... qnm, kn, vn)]
|
270 |
+
That ways, when the MP column sharding happens, this tensor will be sharded keeping all the
|
271 |
+
queries and keys intact. In order to get the correct qkv, we first break into groups, and then
|
272 |
+
index into the groups.
|
273 |
+
"""
|
274 |
+
|
275 |
+
intermediate_shape = (bs, sq, -1, (self.num_q_per_kv + 2), self.head_dim)
|
276 |
+
mixed_x_layer = mixed_x_layer.view(*intermediate_shape)
|
277 |
+
q = mixed_x_layer[:, :, :, :-2]
|
278 |
+
k = mixed_x_layer[:, :, :, [-2]]
|
279 |
+
v = mixed_x_layer[:, :, :, [-1]]
|
280 |
+
q, k, v = [
|
281 |
+
rearrange(
|
282 |
+
x,
|
283 |
+
"bs sq group nh hn -> bs sq (group nh) hn"
|
284 |
+
) for x in (q, k, v)
|
285 |
+
]
|
286 |
+
return q, k, v
|
287 |
+
|
288 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._unpad_input
|
289 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
290 |
+
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
291 |
+
|
292 |
+
|
293 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
294 |
+
|
295 |
+
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
296 |
+
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
297 |
+
|
298 |
+
if query_length == kv_seq_len:
|
299 |
+
query_layer = index_first_axis(
|
300 |
+
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
301 |
+
)
|
302 |
+
cu_seqlens_q = cu_seqlens_k
|
303 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
304 |
+
indices_q = indices_k
|
305 |
+
elif query_length == 1:
|
306 |
+
max_seqlen_in_batch_q = 1
|
307 |
+
cu_seqlens_q = torch.arange(
|
308 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
309 |
+
) # There is a memcpy here, that is very bad.
|
310 |
+
indices_q = cu_seqlens_q[:-1]
|
311 |
+
query_layer = query_layer.squeeze(1)
|
312 |
+
else:
|
313 |
+
# The -q_len: slice assumes left padding.
|
314 |
+
attention_mask = attention_mask[:, -query_length:]
|
315 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
316 |
+
|
317 |
+
return (
|
318 |
+
query_layer,
|
319 |
+
key_layer,
|
320 |
+
value_layer,
|
321 |
+
indices_q,
|
322 |
+
(cu_seqlens_q, cu_seqlens_k),
|
323 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
324 |
+
)
|
325 |
+
|
326 |
+
def _apply_blocksparse_attention(
|
327 |
+
self,
|
328 |
+
q: torch.Tensor,
|
329 |
+
k: torch.Tensor,
|
330 |
+
v: torch.Tensor,
|
331 |
+
attention_mask: Optional[torch.LongTensor],
|
332 |
+
return_attention_probs: bool = False,
|
333 |
+
) -> torch.Tensor:
|
334 |
+
"""
|
335 |
+
Applies blocksparse attention to the input tensors.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
q (torch.Tensor): The query tensor of shape (bs, nqp, seq_len, hn).
|
339 |
+
k (torch.Tensor): The key tensor of shape (bs, nkp, seq_len, hn).
|
340 |
+
v (torch.Tensor): The value tensor of shape (bs, nkp, seq_len, hn).
|
341 |
+
attention_mask (Optional[torch.LongTensor]): The attention mask tensor of shape (bs, seq_len).
|
342 |
+
return_attention_probs (bool, optional): Whether to return attention probabilities. Defaults to False.
|
343 |
+
|
344 |
+
Returns:
|
345 |
+
torch.Tensor: The context layer tensor of shape (bs, nqp, seq_len, hn).
|
346 |
+
"""
|
347 |
+
assert not return_attention_probs, "return_attention_probs is not supported for blocksparse attention"
|
348 |
+
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
349 |
+
# shape: (bs, nqp, seq_len, hn)
|
350 |
+
if torch.is_grad_enabled():
|
351 |
+
# Training or non-batched inference
|
352 |
+
context_layer = self._blocksparse_layer(
|
353 |
+
q=q, k=k, v=v, sm_scale=self.softmax_scale
|
354 |
+
)
|
355 |
+
elif attention_mask is None:
|
356 |
+
if q.size(0) != 1:
|
357 |
+
logger.warning_once(
|
358 |
+
"You are attempting to do batched inference without passing the attention mask.\n"
|
359 |
+
"This is okay if you are running loglikelihood requests. However, if you want to do generation, "
|
360 |
+
"this probably won't work as expected. Please pass the attention mask to the forward function."
|
361 |
+
)
|
362 |
+
context_layer = self._blocksparse_layer(
|
363 |
+
q=q, k=k, v=v, sm_scale=self.softmax_scale
|
364 |
+
)
|
365 |
+
else:
|
366 |
+
"""
|
367 |
+
Shapes of tensors are as follows:
|
368 |
+
q: (bs, nqp, seq_len, hdim)
|
369 |
+
k: (bs, nkp, seq_len, hdim)
|
370 |
+
v: (bs, nkp, seq_len, hdim)
|
371 |
+
We first need to transpose the shapes to fit what the
|
372 |
+
kernel needs, and the reinvert it back at the end of the operations
|
373 |
+
"""
|
374 |
+
assert attention_mask.ndim == 2, "The kernel, like flash-attention-2, only supports 2d attention masks ..."
|
375 |
+
left_paddings = attention_mask.shape[1] - attention_mask.sum(dim=-1)
|
376 |
+
# shape: (bs, seq_len, nqp, hdim)
|
377 |
+
q = q.transpose(1, 2).contiguous()
|
378 |
+
# shape: (bs, seq_len, nkp, hdim)
|
379 |
+
k = k.transpose(1, 2).contiguous()
|
380 |
+
# shape: (bs, seq_len, nkp, hdim)
|
381 |
+
v = v.transpose(1, 2).contiguous()
|
382 |
+
context_layer = self._blocksparse_layer(
|
383 |
+
q=q, k=k, v=v, sm_scale=self.softmax_scale, left_paddings=left_paddings.to(torch.int32)
|
384 |
+
)
|
385 |
+
# shape: (bs, nqp, seq_len, hdim)
|
386 |
+
context_layer = context_layer.transpose(1, 2).contiguous()
|
387 |
+
return context_layer
|
388 |
+
|
389 |
+
def _apply_dense_attention(
|
390 |
+
self,
|
391 |
+
q: torch.Tensor,
|
392 |
+
k: torch.Tensor,
|
393 |
+
v: torch.Tensor,
|
394 |
+
attention_mask: torch.Tensor,
|
395 |
+
return_attention_probs: bool = False,
|
396 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
397 |
+
"""
|
398 |
+
Apply dense attention
|
399 |
+
|
400 |
+
Args:
|
401 |
+
q (torch.Tensor):
|
402 |
+
The query tensor, shape: (bs, num_query_heads, seq_len, head_size)
|
403 |
+
k (torch.Tensor):
|
404 |
+
The key tensor, shape: (bs, num_query_heads, seq_len, head_size)
|
405 |
+
v (torch.Tensor):
|
406 |
+
The value tensor, shape: (bs, num_query_heads, seq_len, head_size)
|
407 |
+
|
408 |
+
return_attention_probs (bool, optional):
|
409 |
+
Return the attention probabilities. Defaults to False.
|
410 |
+
|
411 |
+
Returns:
|
412 |
+
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
413 |
+
Return the output of the attention aggregation. If `return_attention_probs` is True, then
|
414 |
+
also return the attention probabilities
|
415 |
+
|
416 |
+
.. note::
|
417 |
+
Right now, am assuming the expansion for the query key values is already done
|
418 |
+
outside. But ideally, since Flash attention handles the GQA correctly, we can
|
419 |
+
avoid doing that.
|
420 |
+
|
421 |
+
"""
|
422 |
+
attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
|
423 |
+
# Get into the correct shape for the Flash Attention API
|
424 |
+
# shape: (bs, seq_len, nqp, hn)
|
425 |
+
q = q.transpose(1, 2).contiguous()
|
426 |
+
query_length = q.size(1)
|
427 |
+
# shape: (bs, seq_len, npq, hn)
|
428 |
+
k = k.transpose(1, 2).contiguous()
|
429 |
+
# shape: (bs, seq_len, npq, hn)
|
430 |
+
v = v.transpose(1, 2).contiguous()
|
431 |
+
|
432 |
+
if attention_mask is not None:
|
433 |
+
causal = q.size(2) == k.size(2)
|
434 |
+
batch_size = q.shape[0]
|
435 |
+
flat_q, flat_k, flat_v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
436 |
+
q, k, v, attention_mask, query_length
|
437 |
+
)
|
438 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
439 |
+
max_seqlen_q, max_seqlen_k = max_seq_lens
|
440 |
+
flat_kv = torch.cat((flat_k.unsqueeze(1), flat_v.unsqueeze(1)), dim=1)
|
441 |
+
attn_output_unpad = flash_attn_varlen_kvpacked_func(
|
442 |
+
q=flat_q,
|
443 |
+
kv=flat_kv,
|
444 |
+
cu_seqlens_q=cu_seqlens_q,
|
445 |
+
cu_seqlens_k=cu_seqlens_k,
|
446 |
+
max_seqlen_q=max_seqlen_q,
|
447 |
+
max_seqlen_k=max_seqlen_k,
|
448 |
+
dropout_p=attention_dropout_prob,
|
449 |
+
softmax_scale=self.softmax_scale,
|
450 |
+
causal=causal,
|
451 |
+
return_attn_probs=return_attention_probs
|
452 |
+
)
|
453 |
+
attention_output = pad_input(
|
454 |
+
attn_output_unpad, indices_q, batch_size, query_length
|
455 |
+
)
|
456 |
+
else:
|
457 |
+
kv = torch.cat((k.unsqueeze(2), v.unsqueeze(2)), dim=2)
|
458 |
+
cu_seqlens_q = torch.arange(
|
459 |
+
0, (q.size(0) + 1), device=q.device, dtype=torch.int32
|
460 |
+
) * q.size(1)
|
461 |
+
cu_seqlens_kv = torch.arange(
|
462 |
+
0, (kv.size(0) + 1), device=kv.device, dtype=torch.int32
|
463 |
+
) * kv.size(1)
|
464 |
+
max_seqlen_q = q.size(1)
|
465 |
+
max_seqlen_k = kv.size(1)
|
466 |
+
attention_output = flash_attn_varlen_kvpacked_func(
|
467 |
+
q=collapse_first_n_dims(q, 2),
|
468 |
+
kv=collapse_first_n_dims(kv, 2),
|
469 |
+
cu_seqlens_q=cu_seqlens_q,
|
470 |
+
cu_seqlens_k=cu_seqlens_kv,
|
471 |
+
max_seqlen_q=max_seqlen_q,
|
472 |
+
max_seqlen_k=max_seqlen_k,
|
473 |
+
dropout_p=attention_dropout_prob,
|
474 |
+
softmax_scale=self.softmax_scale,
|
475 |
+
causal=q.size(1) == kv.size(1),
|
476 |
+
return_attn_probs=return_attention_probs
|
477 |
+
)
|
478 |
+
if return_attention_probs:
|
479 |
+
(context_layer, attn_probs) = attention_output
|
480 |
+
context_layer = context_layer.view(q.size(0), q.size(1), -1, q.size(3)).transpose(1, 2).contiguous()
|
481 |
+
return (context_layer, attn_probs)
|
482 |
+
context_layer = attention_output
|
483 |
+
context_layer = context_layer.view(q.size(0), q.size(1), -1, q.size(3)).transpose(1, 2).contiguous()
|
484 |
+
return context_layer
|
485 |
+
|
486 |
+
|
487 |
+
def expand_kv_to_q_size(self, kv: torch.Tensor, num_q_per_kv: int) -> torch.Tensor:
|
488 |
+
"""
|
489 |
+
Expand the key-value tensor to match the size of the query tensor.
|
490 |
+
|
491 |
+
Args:
|
492 |
+
kv (torch.Tensor): The key-value tensor of shape (bsz, nkp, 2, seq_len, hdim).
|
493 |
+
num_q_per_kv (int): The number of queries per key-value.
|
494 |
+
|
495 |
+
Returns:
|
496 |
+
torch.Tensor: The expanded key-value tensor of shape (bsz, nqp, 2, seq_len, hdim).
|
497 |
+
Where nqp = num_q_per_kv * nkp
|
498 |
+
|
499 |
+
.. note(bapatra)::
|
500 |
+
Right now, I am using a repeat_interleave to expand the kv to the size of q.
|
501 |
+
This incurs a memory penalty, since the tensors are actually copied.
|
502 |
+
TODO: If this does yield benefits, then potentially we can use the re-written
|
503 |
+
flash attention kernel that can handle GQA.
|
504 |
+
"""
|
505 |
+
|
506 |
+
repeats = torch.tensor([num_q_per_kv] * kv.size(1)).to(kv.device)
|
507 |
+
total = repeats.sum()
|
508 |
+
expanded_kv = torch.repeat_interleave(
|
509 |
+
kv,
|
510 |
+
repeats=repeats,
|
511 |
+
dim=1,
|
512 |
+
output_size=total
|
513 |
+
)
|
514 |
+
return expanded_kv
|
515 |
+
|
516 |
+
def forward(
|
517 |
+
self,
|
518 |
+
hidden_states: torch.Tensor,
|
519 |
+
attention_mask: Optional[torch.Tensor] = None,
|
520 |
+
position_ids: Optional[torch.LongTensor] = None,
|
521 |
+
past_key_values: Optional[Cache] = None,
|
522 |
+
output_attentions: bool = False,
|
523 |
+
use_cache: bool = False,
|
524 |
+
**kwargs,
|
525 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
526 |
+
"""
|
527 |
+
The forward function of the Self Attention Layer.
|
528 |
+
|
529 |
+
Args:
|
530 |
+
hidden_states (torch.Tensor):
|
531 |
+
The input tensor of shape (bs, q_len, h).
|
532 |
+
attention_mask (Optional[torch.Tensor], optional):
|
533 |
+
The attention mask tensor of shape (bs, seq_len). This is the 2D attention mask tensor as is standard in the flash-attention
|
534 |
+
kernel.
|
535 |
+
Defaults to None.
|
536 |
+
position_ids (Optional[torch.LongTensor], optional):
|
537 |
+
The position ids tensor of shape (bs, q_len). Defaults to None. Unused by the function.
|
538 |
+
past_key_value (Optional[Cache], optional):
|
539 |
+
The previous kv cache values. Defaults to None.
|
540 |
+
output_attentions (bool, optional):
|
541 |
+
Whether to return the attention scores. Defaults to False.
|
542 |
+
.. note::
|
543 |
+
For the blocksparse attention kernel, we do not support returning the attention scores.
|
544 |
+
use_cache (bool, optional):
|
545 |
+
Whether to use the cache for storing the kv. Defaults to False.
|
546 |
+
|
547 |
+
Returns:
|
548 |
+
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
549 |
+
The output tensor of shape (bs, q_len, h),
|
550 |
+
the attention scores tensor of shape (bs, nqp, q_len, seq_len) if `output_attentions` is True,
|
551 |
+
and the updated cache values if `use_cache` is True.
|
552 |
+
|
553 |
+
Notations:
|
554 |
+
------------
|
555 |
+
bs: batch size
|
556 |
+
sq_len: sequence length of the entire sequence
|
557 |
+
q_len: sequence length of the query
|
558 |
+
cache_sq: sequence length in the cache
|
559 |
+
If there is no cache then cache_sq = 0
|
560 |
+
and sq_len = q_len
|
561 |
+
otherwise sq_len = q_len + cache_sq
|
562 |
+
h: hidden size
|
563 |
+
nq: number of query heads
|
564 |
+
nkv: number of key heads
|
565 |
+
hn: hidden size per head
|
566 |
+
hn = h // nq
|
567 |
+
nqp: number of query heads (per MP partition)
|
568 |
+
nqp = nq // (num mp partitions)
|
569 |
+
nkvp: number of key-value heads (per MP partition)
|
570 |
+
nkvp = nk // (num mp partitions)
|
571 |
+
|
572 |
+
"""
|
573 |
+
# shape: (bs, q_len, h)
|
574 |
+
bsz, q_len, _ = hidden_states.size()
|
575 |
+
|
576 |
+
# shape: (bs, q_len, (nqp + 2 * nkvp) * hn)
|
577 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
578 |
+
# shape: (bs, q_len, nqp, hn), shape: (bs, q_len, nkvp, hn), shape: (bs, q_len, nkvp, hn)
|
579 |
+
q, k, v = self._split_heads(mixed_x_layer)
|
580 |
+
|
581 |
+
# shape: (bs, qnp, q_len, hn)
|
582 |
+
query_states = q.permute(0, 2, 1, 3).contiguous()
|
583 |
+
# shape: (bs, nkvp, q_len, hn)
|
584 |
+
key_states = k.permute(0, 2, 1, 3).contiguous()
|
585 |
+
# shape: (bs, nkvp, q_len, hn)
|
586 |
+
value_states = v.permute(0, 2, 1, 3).contiguous()
|
587 |
+
|
588 |
+
kv_seq_len = key_states.shape[-2]
|
589 |
+
if past_key_values is not None:
|
590 |
+
if self.layer_idx is None:
|
591 |
+
raise ValueError(
|
592 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
593 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
594 |
+
"with a layer index."
|
595 |
+
)
|
596 |
+
if self.rotary_emb is not None:
|
597 |
+
seqlen_offset = past_key_values.get_usable_length(kv_seq_len, layer_idx=self.layer_idx)
|
598 |
+
# shape: (bs, nqp, q_len, hn), shape: (bs, nkvp, q_len, hn)
|
599 |
+
query_states, key_states = self.rotary_emb(
|
600 |
+
query_states, key_states, seq_dimension=2, seqlen_offset=seqlen_offset
|
601 |
+
)
|
602 |
+
key_states, value_states = past_key_values.update(key_states=key_states, value_states=value_states, layer_idx=self.layer_idx)
|
603 |
+
else:
|
604 |
+
# In this case seq_len = q_len and cache_sq = 0
|
605 |
+
if self.rotary_emb is not None:
|
606 |
+
# shape: (bs, nqp, seq_len, hn), shape: (bs, nkvp, seq_len, hn)
|
607 |
+
query_states, key_states = self.rotary_emb(query_states, key_states, seq_dimension=2)
|
608 |
+
|
609 |
+
# shape: (bs, nkvp, 2, seq_len, hn)
|
610 |
+
kv_states = torch.cat((key_states.unsqueeze(2), value_states.unsqueeze(2)), dim=2)
|
611 |
+
# shape: (bs, nqp, 2, seq_len, hn)
|
612 |
+
expanded_kv_states = self.expand_kv_to_q_size(kv_states, num_q_per_kv=self.num_q_per_kv)
|
613 |
+
# shape: (bs, nqp, seq_len, hn), shape: (bs, nqp, seq_len, hn)
|
614 |
+
expanded_key_states, expanded_value_states = expanded_kv_states[:, :, 0], expanded_kv_states[:, :, 1]
|
615 |
+
if self.blocksparse:
|
616 |
+
attn_function_output = self._apply_blocksparse_attention(
|
617 |
+
q=query_states,
|
618 |
+
k=expanded_key_states,
|
619 |
+
v=expanded_value_states,
|
620 |
+
attention_mask=attention_mask,
|
621 |
+
return_attention_probs=output_attentions
|
622 |
+
)
|
623 |
+
else:
|
624 |
+
attn_function_output = self._apply_dense_attention(
|
625 |
+
q=query_states,
|
626 |
+
k=expanded_key_states,
|
627 |
+
v=expanded_value_states,
|
628 |
+
attention_mask=attention_mask,
|
629 |
+
return_attention_probs=output_attentions
|
630 |
+
)
|
631 |
+
|
632 |
+
attn_weights = None
|
633 |
+
if output_attentions:
|
634 |
+
attn_output, attn_weights = attn_function_output
|
635 |
+
else:
|
636 |
+
# shape: (bs, nqp, seq_len, hn)
|
637 |
+
attn_output = attn_function_output
|
638 |
+
# shape: (bs, seq_len, nqp, hn)
|
639 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
640 |
+
|
641 |
+
# shape: (bs, seq_len, h)
|
642 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
643 |
+
attn_output = self.dense(attn_output)
|
644 |
+
return attn_output, attn_weights, past_key_values
|
645 |
+
|
646 |
+
|
647 |
+
class Phi3SmallDecoderLayer(nn.Module):
|
648 |
+
def __init__(self, config: Phi3SmallConfig, layer_idx: int):
|
649 |
+
super().__init__()
|
650 |
+
self.hidden_size = config.hidden_size
|
651 |
+
self.self_attn = Phi3SmallSelfAttention(config, layer_idx)
|
652 |
+
self.mlp = Phi3SmallMLP(config)
|
653 |
+
|
654 |
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
655 |
+
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
656 |
+
|
657 |
+
def forward(
|
658 |
+
self,
|
659 |
+
hidden_states: torch.Tensor,
|
660 |
+
attention_mask: Optional[torch.Tensor] = None,
|
661 |
+
position_ids: Optional[torch.LongTensor] = None,
|
662 |
+
past_key_values: Optional[Cache] = None,
|
663 |
+
output_attentions: Optional[bool] = None,
|
664 |
+
use_cache: Optional[bool] = None,
|
665 |
+
**kwargs,
|
666 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Cache]]:
|
667 |
+
residual = hidden_states
|
668 |
+
hidden_states = self.input_layernorm(hidden_states)
|
669 |
+
|
670 |
+
# Self Attention
|
671 |
+
hidden_states, self_attn_weights, present_key_values = self.self_attn(
|
672 |
+
hidden_states=hidden_states,
|
673 |
+
attention_mask=attention_mask,
|
674 |
+
position_ids=position_ids,
|
675 |
+
past_key_values=past_key_values,
|
676 |
+
output_attentions=output_attentions,
|
677 |
+
use_cache=use_cache,
|
678 |
+
)
|
679 |
+
hidden_states = residual + hidden_states
|
680 |
+
|
681 |
+
# Fully Connected
|
682 |
+
residual = hidden_states
|
683 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
684 |
+
hidden_states = self.mlp(hidden_states)
|
685 |
+
hidden_states = residual + hidden_states
|
686 |
+
|
687 |
+
outputs = (hidden_states,)
|
688 |
+
|
689 |
+
if output_attentions:
|
690 |
+
outputs += (self_attn_weights,)
|
691 |
+
|
692 |
+
if use_cache:
|
693 |
+
outputs += (present_key_values,)
|
694 |
+
|
695 |
+
return outputs
|
696 |
+
|
697 |
+
|
698 |
+
|
699 |
+
class Phi3SmallPreTrainedModel(PreTrainedModel):
|
700 |
+
config_class = Phi3SmallConfig
|
701 |
+
base_model_prefix = "model"
|
702 |
+
supports_gradient_checkpointing = True
|
703 |
+
_no_split_modules = ["Phi3SmallDecoderLayer"]
|
704 |
+
skip_keys_device_placement = "past_key_values"
|
705 |
+
_supports_flash_attn_2 = True
|
706 |
+
_supports_sdpa = False
|
707 |
+
_supports_cache_class = True
|
708 |
+
|
709 |
+
def _init_weights(self, module: nn.Module):
|
710 |
+
std = self.config.initializer_range
|
711 |
+
if isinstance(module, nn.Linear):
|
712 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
713 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
714 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
715 |
+
elif isinstance(module, nn.Embedding):
|
716 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
717 |
+
if module.padding_idx is not None:
|
718 |
+
module.weight.data[module.padding_idx].zero_()
|
719 |
+
elif isinstance(module, nn.LayerNorm):
|
720 |
+
module.bias.data.zero_()
|
721 |
+
module.weight.data.fill_(1.0)
|
722 |
+
|
723 |
+
# The output projection on the decoder attention layer as well as the down_proj in the MLP are scaled
|
724 |
+
# differently (dubbed `output_layer_init_method` in the Megatron code). This is replicated here
|
725 |
+
for name, p in module.named_parameters():
|
726 |
+
if any(x in name for x in ("c_proj.weight", "down_proj.weight", "o_proj.weight")):
|
727 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
728 |
+
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers)))
|
729 |
+
|
730 |
+
|
731 |
+
class Phi3SmallModel(Phi3SmallPreTrainedModel):
|
732 |
+
|
733 |
+
def __init__(self, config):
|
734 |
+
super().__init__(config)
|
735 |
+
self.config = config
|
736 |
+
|
737 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
738 |
+
|
739 |
+
# Embedding Dropout
|
740 |
+
self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob)
|
741 |
+
|
742 |
+
# MuP Embedding scaling
|
743 |
+
self.mup_embedding_multiplier = config.mup_embedding_multiplier
|
744 |
+
|
745 |
+
self.layers = nn.ModuleList([Phi3SmallDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
746 |
+
|
747 |
+
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
748 |
+
|
749 |
+
self.gradient_checkpointing = False
|
750 |
+
|
751 |
+
# Initialize weights and apply final processing
|
752 |
+
self.post_init()
|
753 |
+
|
754 |
+
def get_input_embeddings(self):
|
755 |
+
return self.embed_tokens
|
756 |
+
|
757 |
+
def set_input_embeddings(self, value):
|
758 |
+
self.embed_tokens = value
|
759 |
+
|
760 |
+
@property
|
761 |
+
def pad_sequence_to_multiple_of_64(self):
|
762 |
+
# We only need to do this for the backward pass. So only required
|
763 |
+
# when we are in the context of generating gradients
|
764 |
+
return self.config.pad_sequence_to_multiple_of_64 and torch.is_grad_enabled()
|
765 |
+
|
766 |
+
def forward(
|
767 |
+
self,
|
768 |
+
input_ids: torch.LongTensor = None,
|
769 |
+
attention_mask: Optional[torch.Tensor] = None,
|
770 |
+
position_ids: Optional[torch.LongTensor] = None,
|
771 |
+
past_key_values: Optional[Union[Cache, LegacyCache]] = None,
|
772 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
773 |
+
use_cache: Optional[bool] = None,
|
774 |
+
output_attentions: Optional[bool] = None,
|
775 |
+
output_hidden_states: Optional[bool] = None,
|
776 |
+
return_dict: Optional[bool] = None,
|
777 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
778 |
+
|
779 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
780 |
+
output_hidden_states = (
|
781 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
782 |
+
)
|
783 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
784 |
+
|
785 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
786 |
+
|
787 |
+
if input_ids is not None and inputs_embeds is not None:
|
788 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
789 |
+
elif input_ids is not None:
|
790 |
+
batch_size, seq_length = input_ids.shape
|
791 |
+
elif inputs_embeds is not None:
|
792 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
793 |
+
else:
|
794 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
795 |
+
|
796 |
+
if self.gradient_checkpointing and self.training:
|
797 |
+
if use_cache:
|
798 |
+
logger.warning_once(
|
799 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
800 |
+
)
|
801 |
+
use_cache = False
|
802 |
+
|
803 |
+
past_key_values_length = 0
|
804 |
+
|
805 |
+
if use_cache:
|
806 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
807 |
+
if use_legacy_cache:
|
808 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
809 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
810 |
+
|
811 |
+
if position_ids is None:
|
812 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
813 |
+
position_ids = torch.arange(
|
814 |
+
past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device
|
815 |
+
)
|
816 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
817 |
+
else:
|
818 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
819 |
+
|
820 |
+
if attention_mask is not None:
|
821 |
+
if batch_size <= 0:
|
822 |
+
raise ValueError("batch_size has to be defined and > 0")
|
823 |
+
|
824 |
+
if inputs_embeds is None:
|
825 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
826 |
+
inputs_embeds = self.embedding_dropout(inputs_embeds)
|
827 |
+
|
828 |
+
if self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0:
|
829 |
+
inputs_embeds = inputs_embeds * self.mup_embedding_multiplier
|
830 |
+
|
831 |
+
residual = 0
|
832 |
+
if self.pad_sequence_to_multiple_of_64:
|
833 |
+
# note(bapatra): Since we don't particularly use the position_ids and the attention mask
|
834 |
+
# we don't need to pad them
|
835 |
+
inputs_embeds, residual = pad_tensor_to_next_mult_of(tensor=inputs_embeds, dim=1, n=64)
|
836 |
+
|
837 |
+
hidden_states = inputs_embeds
|
838 |
+
|
839 |
+
# decoder layers
|
840 |
+
all_hidden_states = () if output_hidden_states else None
|
841 |
+
all_self_attns = () if output_attentions else None
|
842 |
+
next_decoder_cache = None
|
843 |
+
|
844 |
+
for decoder_layer in self.layers:
|
845 |
+
if output_hidden_states:
|
846 |
+
all_hidden_states += (hidden_states,)
|
847 |
+
|
848 |
+
if self.gradient_checkpointing and self.training:
|
849 |
+
layer_outputs = self._gradient_checkpointing_func(
|
850 |
+
decoder_layer.__call__,
|
851 |
+
hidden_states,
|
852 |
+
attention_mask,
|
853 |
+
position_ids,
|
854 |
+
past_key_values,
|
855 |
+
output_attentions,
|
856 |
+
use_cache,
|
857 |
+
)
|
858 |
+
else:
|
859 |
+
layer_outputs = decoder_layer(
|
860 |
+
hidden_states,
|
861 |
+
attention_mask=attention_mask,
|
862 |
+
position_ids=position_ids,
|
863 |
+
past_key_values=past_key_values,
|
864 |
+
output_attentions=output_attentions,
|
865 |
+
use_cache=use_cache,
|
866 |
+
)
|
867 |
+
hidden_states = layer_outputs[0]
|
868 |
+
|
869 |
+
if use_cache:
|
870 |
+
# Following the Mistral schema for layer return values
|
871 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
872 |
+
if output_attentions:
|
873 |
+
all_self_attns += (layer_outputs[1],)
|
874 |
+
|
875 |
+
hidden_states = self.final_layernorm(hidden_states)
|
876 |
+
|
877 |
+
if residual > 0:
|
878 |
+
hidden_states = strip_padding_from_tensor(tensor=hidden_states, dim=1, residual=residual)
|
879 |
+
|
880 |
+
# add hidden states from the last decoder layer
|
881 |
+
if output_hidden_states:
|
882 |
+
all_hidden_states += (hidden_states,)
|
883 |
+
|
884 |
+
next_cache = None
|
885 |
+
if use_cache:
|
886 |
+
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
887 |
+
|
888 |
+
if not return_dict:
|
889 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
890 |
+
return BaseModelOutputWithPast(
|
891 |
+
last_hidden_state=hidden_states,
|
892 |
+
past_key_values=next_cache,
|
893 |
+
hidden_states=all_hidden_states,
|
894 |
+
attentions=all_self_attns,
|
895 |
+
)
|
896 |
+
|
897 |
+
|
898 |
+
class Phi3SmallForCausalLM(Phi3SmallPreTrainedModel):
|
899 |
+
_tied_weights_keys = ["lm_head.weight"]
|
900 |
+
|
901 |
+
def __init__(self, config):
|
902 |
+
super().__init__(config)
|
903 |
+
self.model = Phi3SmallModel(config)
|
904 |
+
self.vocab_size = config.vocab_size
|
905 |
+
self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
|
906 |
+
self.mup_width_multiplier = config.mup_width_multiplier
|
907 |
+
|
908 |
+
# Create the mask for the dummy tokens in the vocabulary
|
909 |
+
dummy_token_indices = config.dummy_token_indices
|
910 |
+
dummy_tokens_mask = torch.zeros(self.vocab_size).bool()
|
911 |
+
dummy_tokens_mask[dummy_token_indices] = True
|
912 |
+
# shape: (vocab_size,)
|
913 |
+
self.register_buffer("dummy_tokens_mask", dummy_tokens_mask, persistent=False)
|
914 |
+
|
915 |
+
# Initialize weights and apply final processing
|
916 |
+
self.post_init()
|
917 |
+
|
918 |
+
def get_input_embeddings(self):
|
919 |
+
return self.model.embed_tokens
|
920 |
+
|
921 |
+
def set_input_embeddings(self, value):
|
922 |
+
self.model.embed_tokens = value
|
923 |
+
|
924 |
+
def get_output_embeddings(self):
|
925 |
+
return self.lm_head
|
926 |
+
|
927 |
+
def set_output_embeddings(self, value):
|
928 |
+
self.lm_head = value
|
929 |
+
|
930 |
+
def set_decoder(self, decoder):
|
931 |
+
self.model = decoder
|
932 |
+
|
933 |
+
def get_decoder(self):
|
934 |
+
return self.model
|
935 |
+
|
936 |
+
def forward(
|
937 |
+
self,
|
938 |
+
input_ids: torch.LongTensor = None,
|
939 |
+
attention_mask: Optional[torch.Tensor] = None,
|
940 |
+
position_ids: Optional[torch.LongTensor] = None,
|
941 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
942 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
943 |
+
labels: Optional[torch.LongTensor] = None,
|
944 |
+
use_cache: Optional[bool] = None,
|
945 |
+
output_attentions: Optional[bool] = None,
|
946 |
+
output_hidden_states: Optional[bool] = None,
|
947 |
+
return_dict: Optional[bool] = None,
|
948 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
949 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
950 |
+
output_hidden_states = (
|
951 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
952 |
+
)
|
953 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
954 |
+
|
955 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
956 |
+
outputs = self.model(
|
957 |
+
input_ids=input_ids,
|
958 |
+
attention_mask=attention_mask,
|
959 |
+
position_ids=position_ids,
|
960 |
+
past_key_values=past_key_values,
|
961 |
+
inputs_embeds=inputs_embeds,
|
962 |
+
use_cache=use_cache,
|
963 |
+
output_attentions=output_attentions,
|
964 |
+
output_hidden_states=output_hidden_states,
|
965 |
+
return_dict=return_dict,
|
966 |
+
)
|
967 |
+
|
968 |
+
hidden_states = outputs[0]
|
969 |
+
logits = self.lm_head(hidden_states)
|
970 |
+
logits = logits.float()
|
971 |
+
if self.mup_width_multiplier:
|
972 |
+
logits = logits / self.mup_width_multiplier
|
973 |
+
logits = logits.masked_fill(self.dummy_tokens_mask, min_value_of_dtype(logits.dtype))
|
974 |
+
|
975 |
+
loss = None
|
976 |
+
if labels is not None:
|
977 |
+
# Shift so that tokens < n predict n
|
978 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
979 |
+
shift_labels = labels[..., 1:].contiguous()
|
980 |
+
# Flatten the tokens
|
981 |
+
loss_fct = nn.CrossEntropyLoss()
|
982 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
983 |
+
shift_labels = shift_labels.view(-1)
|
984 |
+
# Enable model parallelism
|
985 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
986 |
+
loss = loss_fct(shift_logits, shift_labels)
|
987 |
+
|
988 |
+
if not return_dict:
|
989 |
+
output = (logits,) + outputs[1:]
|
990 |
+
return (loss,) + output if loss is not None else output
|
991 |
+
|
992 |
+
return CausalLMOutputWithPast(
|
993 |
+
loss=loss,
|
994 |
+
logits=logits,
|
995 |
+
past_key_values=outputs.past_key_values,
|
996 |
+
hidden_states=outputs.hidden_states,
|
997 |
+
attentions=outputs.attentions,
|
998 |
+
)
|
999 |
+
|
1000 |
+
def prepare_inputs_for_generation(
|
1001 |
+
self,
|
1002 |
+
input_ids: torch.LongTensor,
|
1003 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1004 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1005 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1006 |
+
**kwargs
|
1007 |
+
) -> Dict[str, Any]:
|
1008 |
+
# only last token for inputs_ids if past is defined in kwargs
|
1009 |
+
if past_key_values:
|
1010 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
1011 |
+
|
1012 |
+
position_ids = kwargs.get("position_ids", None)
|
1013 |
+
|
1014 |
+
if attention_mask is not None and position_ids is None:
|
1015 |
+
# create position_ids on the fly for batch generation
|
1016 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1017 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1018 |
+
if past_key_values:
|
1019 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
1020 |
+
else:
|
1021 |
+
position_ids = None
|
1022 |
+
|
1023 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1024 |
+
if inputs_embeds is not None and past_key_values is None:
|
1025 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1026 |
+
else:
|
1027 |
+
model_inputs = {"input_ids": input_ids}
|
1028 |
+
|
1029 |
+
model_inputs.update(
|
1030 |
+
{
|
1031 |
+
"past_key_values": past_key_values,
|
1032 |
+
"use_cache": kwargs.get("use_cache"),
|
1033 |
+
"position_ids": position_ids,
|
1034 |
+
"attention_mask": attention_mask,
|
1035 |
+
}
|
1036 |
+
)
|
1037 |
+
return model_inputs
|
1038 |
+
|
1039 |
+
|
1040 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with Mistral -> Phi3Small
|
1041 |
+
class Phi3SmallForSequenceClassification(Phi3SmallPreTrainedModel):
|
1042 |
+
def __init__(self, config):
|
1043 |
+
super().__init__(config)
|
1044 |
+
self.num_labels = config.num_labels
|
1045 |
+
self.model = Phi3SmallModel(config)
|
1046 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1047 |
+
|
1048 |
+
# Initialize weights and apply final processing
|
1049 |
+
self.post_init()
|
1050 |
+
|
1051 |
+
def get_input_embeddings(self):
|
1052 |
+
return self.model.embed_tokens
|
1053 |
+
|
1054 |
+
def set_input_embeddings(self, value):
|
1055 |
+
self.model.embed_tokens = value
|
1056 |
+
|
1057 |
+
|
1058 |
+
def forward(
|
1059 |
+
self,
|
1060 |
+
input_ids: torch.LongTensor = None,
|
1061 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1062 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1063 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1064 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1065 |
+
labels: Optional[torch.LongTensor] = None,
|
1066 |
+
use_cache: Optional[bool] = None,
|
1067 |
+
output_attentions: Optional[bool] = None,
|
1068 |
+
output_hidden_states: Optional[bool] = None,
|
1069 |
+
return_dict: Optional[bool] = None,
|
1070 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1071 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1072 |
+
|
1073 |
+
transformer_outputs = self.model(
|
1074 |
+
input_ids,
|
1075 |
+
attention_mask=attention_mask,
|
1076 |
+
position_ids=position_ids,
|
1077 |
+
past_key_values=past_key_values,
|
1078 |
+
inputs_embeds=inputs_embeds,
|
1079 |
+
use_cache=use_cache,
|
1080 |
+
output_attentions=output_attentions,
|
1081 |
+
output_hidden_states=output_hidden_states,
|
1082 |
+
return_dict=return_dict,
|
1083 |
+
)
|
1084 |
+
hidden_states = transformer_outputs[0]
|
1085 |
+
logits = self.score(hidden_states)
|
1086 |
+
|
1087 |
+
if input_ids is not None:
|
1088 |
+
batch_size = input_ids.shape[0]
|
1089 |
+
else:
|
1090 |
+
batch_size = inputs_embeds.shape[0]
|
1091 |
+
|
1092 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1093 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1094 |
+
if self.config.pad_token_id is None:
|
1095 |
+
sequence_lengths = -1
|
1096 |
+
else:
|
1097 |
+
if input_ids is not None:
|
1098 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1099 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1100 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1101 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1102 |
+
else:
|
1103 |
+
sequence_lengths = -1
|
1104 |
+
|
1105 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1106 |
+
|
1107 |
+
loss = None
|
1108 |
+
if labels is not None:
|
1109 |
+
labels = labels.to(logits.device)
|
1110 |
+
if self.config.problem_type is None:
|
1111 |
+
if self.num_labels == 1:
|
1112 |
+
self.config.problem_type = "regression"
|
1113 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1114 |
+
self.config.problem_type = "single_label_classification"
|
1115 |
+
else:
|
1116 |
+
self.config.problem_type = "multi_label_classification"
|
1117 |
+
|
1118 |
+
if self.config.problem_type == "regression":
|
1119 |
+
loss_fct = nn.MSELoss()
|
1120 |
+
if self.num_labels == 1:
|
1121 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1122 |
+
else:
|
1123 |
+
loss = loss_fct(pooled_logits, labels)
|
1124 |
+
elif self.config.problem_type == "single_label_classification":
|
1125 |
+
loss_fct = nn.CrossEntropyLoss()
|
1126 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1127 |
+
elif self.config.problem_type == "multi_label_classification":
|
1128 |
+
loss_fct = nn.BCEWithLogitsLoss()
|
1129 |
+
loss = loss_fct(pooled_logits, labels)
|
1130 |
+
if not return_dict:
|
1131 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1132 |
+
return ((loss,) + output) if loss is not None else output
|
1133 |
+
|
1134 |
+
return SequenceClassifierOutputWithPast(
|
1135 |
+
loss=loss,
|
1136 |
+
logits=pooled_logits,
|
1137 |
+
past_key_values=transformer_outputs.past_key_values,
|
1138 |
+
hidden_states=transformer_outputs.hidden_states,
|
1139 |
+
attentions=transformer_outputs.attentions,
|
1140 |
+
)
|
positional_embedding.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Orginally Taken verbatim from xformers library
|
3 |
+
https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py
|
4 |
+
|
5 |
+
The difference is that xformers seems to assume the inputs to be
|
6 |
+
(bs, head, seq_len, dim) while we assume (bs, seq_len, head, dim)
|
7 |
+
|
8 |
+
"""
|
9 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
10 |
+
#
|
11 |
+
# This source code is licensed under the BSD license found in the
|
12 |
+
# LICENSE file in the root directory of this source tree.
|
13 |
+
|
14 |
+
|
15 |
+
# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
|
16 |
+
# NOTE: Almost the same right now, moving parts to Triton is the next step
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import List, Optional, Tuple, Dict, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import dataclasses
|
23 |
+
from transformers.utils import logging
|
24 |
+
|
25 |
+
from transformers import PretrainedConfig
|
26 |
+
|
27 |
+
is_dacite_available = False
|
28 |
+
try:
|
29 |
+
import dacite
|
30 |
+
is_dacite_available = True
|
31 |
+
except ImportError:
|
32 |
+
pass
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__)
|
35 |
+
|
36 |
+
@dataclasses.dataclass
|
37 |
+
class LongRopeConfig(object):
|
38 |
+
short_factor: List[float]
|
39 |
+
long_factor: List[float]
|
40 |
+
original_max_position_embeddings: int
|
41 |
+
type: str = "longrope"
|
42 |
+
short_mscale: float = -1
|
43 |
+
long_mscale: float = -1
|
44 |
+
|
45 |
+
|
46 |
+
def __post_init__(self):
|
47 |
+
assert self.type in ("longrope", "su"), f"Invalid type {self.type} for LongRopeConfig. Expected longrope / su"
|
48 |
+
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def from_dict(cls, config_dict: Dict[str, Union[float, List[float], int]]) -> "LongRopeConfig":
|
52 |
+
if is_dacite_available:
|
53 |
+
# Preferred since we can also type check the input
|
54 |
+
return dacite.from_dict(data_class=cls, data=config_dict)
|
55 |
+
kwargs = {}
|
56 |
+
for field in dataclasses.fields(cls):
|
57 |
+
if field.name in config_dict:
|
58 |
+
if field.init:
|
59 |
+
kwargs[field.name] = config_dict[field.name]
|
60 |
+
else:
|
61 |
+
raise ValueError(f"Field {field.name} is not initiable")
|
62 |
+
else:
|
63 |
+
if field.default is dataclasses.MISSING:
|
64 |
+
raise ValueError(f"Field {field.name} is required")
|
65 |
+
extra_keys = set(config_dict.keys()) - set(kwargs.keys())
|
66 |
+
if len(extra_keys) > 0:
|
67 |
+
for key in extra_keys:
|
68 |
+
logger.error(f"Unrecognized key {key} in config_dict")
|
69 |
+
raise ValueError(f"Unrecognized keys in config_dict")
|
70 |
+
return cls(**kwargs)
|
71 |
+
|
72 |
+
def rotate_half(x):
|
73 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
74 |
+
return torch.cat((-x2, x1), dim=x1.ndim - 1)
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
@torch.jit.script
|
79 |
+
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int):
|
80 |
+
# NOTE: This could probably be moved to Triton
|
81 |
+
|
82 |
+
if seq_dimension == 0:
|
83 |
+
cos = cos[: x.shape[0], None, None, :]
|
84 |
+
sin = sin[: x.shape[0], None, None, :]
|
85 |
+
elif seq_dimension == 1:
|
86 |
+
# Handle a possible sequence length mismatch in between q and k
|
87 |
+
cos = cos[None, : x.shape[1], None, :]
|
88 |
+
sin = sin[None, : x.shape[1], None, :]
|
89 |
+
elif seq_dimension == 2:
|
90 |
+
cos = cos[None, None, : x.shape[2], :]
|
91 |
+
sin = sin[None, None, : x.shape[2], :]
|
92 |
+
|
93 |
+
return (x * cos) + (rotate_half(x) * sin)
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
class RotaryEmbedding(torch.nn.Module):
|
98 |
+
"""
|
99 |
+
Adapted from the xformers library
|
100 |
+
|
101 |
+
The rotary position embeddings from RoFormer_ (Su et. al).
|
102 |
+
A crucial insight from the method is that the query and keys are
|
103 |
+
transformed by rotation matrices which depend on the relative positions.
|
104 |
+
Other implementations are available in the Rotary Transformer repo_ and in
|
105 |
+
GPT-NeoX_, GPT-NeoX was an inspiration
|
106 |
+
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
107 |
+
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
108 |
+
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
109 |
+
.. warning: Please note that this embedding is not registered on purpose, as it is transformative
|
110 |
+
(it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
|
111 |
+
|
112 |
+
# Arguments
|
113 |
+
:param dim_mode: head dimention
|
114 |
+
:param max_seq_len:
|
115 |
+
:param default_seq_dimension: which dim is the sequence length
|
116 |
+
:param dtype: cos/sin dtype
|
117 |
+
:param use_fused_kernel: if to use customized fused kernel.
|
118 |
+
Note: if used, q, k will be modified inplace. Ok for both forward & backward.
|
119 |
+
"""
|
120 |
+
|
121 |
+
def __init__(
|
122 |
+
self,
|
123 |
+
dim_model: int,
|
124 |
+
*,
|
125 |
+
max_seq_len: Optional[int] = None,
|
126 |
+
dtype: Optional[torch.dtype] = None,
|
127 |
+
base=10000,
|
128 |
+
position_scale=1,
|
129 |
+
device: Optional[torch.device] = None,
|
130 |
+
longrope_config: Optional[LongRopeConfig] = None,
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
self.base = base
|
134 |
+
self.dim_model = dim_model
|
135 |
+
self.max_seq_len = max_seq_len
|
136 |
+
self.longrope_config = longrope_config
|
137 |
+
|
138 |
+
if self.is_longrope:
|
139 |
+
# Keep the maximum range vector, and slice from it as needed
|
140 |
+
self.register_buffer(
|
141 |
+
"range_vector",
|
142 |
+
torch.arange(max_seq_len, device=device, dtype=torch.float32),
|
143 |
+
persistent=False
|
144 |
+
)
|
145 |
+
self.register_buffer(
|
146 |
+
"short_factors",
|
147 |
+
torch.tensor(self.longrope_config.short_factor, dtype=torch.float32),
|
148 |
+
persistent=False
|
149 |
+
)
|
150 |
+
self.register_buffer(
|
151 |
+
"long_factors",
|
152 |
+
torch.tensor(self.longrope_config.long_factor, dtype=torch.float32),
|
153 |
+
persistent=False
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
157 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float().to(device) / self.dim_model))
|
158 |
+
self.register_buffer("inv_freq", inv_freq)
|
159 |
+
|
160 |
+
self.position_scale = position_scale
|
161 |
+
|
162 |
+
if not self.is_longrope:
|
163 |
+
dtype = dtype or torch.get_default_dtype()
|
164 |
+
self._set_cos_sin_cache(
|
165 |
+
seq_len=max_seq_len,
|
166 |
+
device=self.inv_freq.device,
|
167 |
+
dtype=dtype,
|
168 |
+
)
|
169 |
+
@property
|
170 |
+
def is_longrope(self):
|
171 |
+
return self.longrope_config is not None
|
172 |
+
|
173 |
+
@property
|
174 |
+
def original_max_seq_len(self):
|
175 |
+
if self.longrope_config is not None:
|
176 |
+
return self.longrope_config.original_max_position_embeddings
|
177 |
+
logger.warning_once(
|
178 |
+
(
|
179 |
+
"``original_max_seq_len'' is being accessed, but longrope_config has not been set. "
|
180 |
+
"Please only do this if you are sure about the context."
|
181 |
+
)
|
182 |
+
)
|
183 |
+
return self.max_seq_len
|
184 |
+
|
185 |
+
def get_range_vector(self, seq_len: int, device: torch.device):
|
186 |
+
if self.is_longrope:
|
187 |
+
assert seq_len < self.range_vector.shape[0], f"Found seq_len {seq_len} greater than max_seq_len {self.range_vector.shape[0]}"
|
188 |
+
if self.range_vector.device != device:
|
189 |
+
self.range_vector = self.range_vector.to(device)
|
190 |
+
return self.range_vector[:seq_len]
|
191 |
+
return torch.arange(seq_len, device=device, dtype=torch.float32)
|
192 |
+
|
193 |
+
|
194 |
+
def _calc_mscale(self, scale: torch.Tensor) -> torch.Tensor:
|
195 |
+
if scale <= 1.0:
|
196 |
+
return 1.0
|
197 |
+
return math.sqrt(1 + math.log(scale) / math.log(self.original_max_seq_len))
|
198 |
+
|
199 |
+
def _set_cos_sin_cache(
|
200 |
+
self,
|
201 |
+
seq_len: int,
|
202 |
+
device: Optional[torch.device] = None,
|
203 |
+
dtype: Optional[torch.dtype] = None,
|
204 |
+
) -> None:
|
205 |
+
dtype = dtype or torch.get_default_dtype()
|
206 |
+
self.max_seq_len_cached = seq_len
|
207 |
+
t = (torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) * self.position_scale).type_as(self.inv_freq)
|
208 |
+
device_type = device.type if device is not None else "cpu"
|
209 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
210 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
211 |
+
# shape: (seq_len, dim_model // 2)
|
212 |
+
freqs = torch.outer(t, self.inv_freq)
|
213 |
+
# shape: (seq_len, dim_model)
|
214 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
215 |
+
cos = emb.cos()
|
216 |
+
sin = emb.sin()
|
217 |
+
self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
|
218 |
+
self.register_buffer("sin_cached", sin.to(dtype), persistent=False)
|
219 |
+
|
220 |
+
def forward(
|
221 |
+
self, q: torch.Tensor,
|
222 |
+
k: torch.Tensor,
|
223 |
+
seq_dimension: int = 1,
|
224 |
+
seqlen_offset: int = 0,
|
225 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
226 |
+
"""q, k does not include `seqlen_offset`
|
227 |
+
q: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
|
228 |
+
k: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
|
229 |
+
"""
|
230 |
+
if seq_dimension < 0:
|
231 |
+
seq_dimension = k.ndim + seq_dimension
|
232 |
+
assert seq_dimension in (0, 1, 2)
|
233 |
+
seq_len = k.shape[seq_dimension] + seqlen_offset
|
234 |
+
|
235 |
+
if self.is_longrope:
|
236 |
+
if seq_len > self.original_max_seq_len:
|
237 |
+
t = self.get_range_vector(seq_len, device=q.device)
|
238 |
+
rescale_factors = self.long_factors.to(q.device)
|
239 |
+
long_mscale = self.longrope_config.long_mscale
|
240 |
+
mscale = long_mscale if long_mscale > 0 else self._calc_mscale(self.max_seq_len / self.original_max_seq_len)
|
241 |
+
else:
|
242 |
+
t = self.get_range_vector(self.original_max_seq_len, device=q.device)
|
243 |
+
rescale_factors = self.short_factors.to(q.device)
|
244 |
+
short_mscale = self.longrope_config.short_mscale
|
245 |
+
mscale = short_mscale if short_mscale > 0 else 1.0
|
246 |
+
assert rescale_factors.shape == (self.dim_model // 2, ), (
|
247 |
+
f"misaligned shape for LongRoPE rescale factors:\n"
|
248 |
+
f"\tExpected {(self.dim_model // 2, )}, got {rescale_factors.shape}."
|
249 |
+
)
|
250 |
+
inv_freq = 1.0 / (rescale_factors * (self.base ** (torch.arange(0, self.dim_model, 2).float().to(q.device) / self.dim_model)))
|
251 |
+
device_type = q.device.type if q.device is not None else "cpu"
|
252 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
253 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
254 |
+
freqs = torch.outer(t, inv_freq)
|
255 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
256 |
+
cos = emb.cos() * mscale
|
257 |
+
sin = emb.sin() * mscale
|
258 |
+
cos_cached = cos.to(q.dtype)
|
259 |
+
sin_cached = sin.to(q.dtype)
|
260 |
+
else:
|
261 |
+
if seq_len > self.max_seq_len_cached:
|
262 |
+
self._set_cos_sin_cache(
|
263 |
+
seq_len=seq_len,
|
264 |
+
device=k.device,
|
265 |
+
dtype=k.dtype,
|
266 |
+
)
|
267 |
+
cos_cached = self.cos_cached
|
268 |
+
sin_cached = self.sin_cached
|
269 |
+
return (
|
270 |
+
apply_rotary_pos_emb(
|
271 |
+
q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
272 |
+
).to(q.dtype),
|
273 |
+
apply_rotary_pos_emb(
|
274 |
+
k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
275 |
+
).to(k.dtype),
|
276 |
+
)
|
277 |
+
|
278 |
+
@classmethod
|
279 |
+
def from_config(cls, config: PretrainedConfig) -> "RotaryEmbedding":
|
280 |
+
kwargs = dict(
|
281 |
+
dim_model=config.hidden_size // config.num_attention_heads,
|
282 |
+
max_seq_len=config.max_position_embeddings,
|
283 |
+
base=config.rope_embedding_base,
|
284 |
+
position_scale=config.rope_position_scale,
|
285 |
+
)
|
286 |
+
if config.rope_scaling is not None:
|
287 |
+
kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
|
288 |
+
return cls(**kwargs)
|
smash_config.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"api_key": null,
|
3 |
+
"verify_url": "http://johnrachwan.pythonanywhere.com",
|
4 |
+
"smash_config": {
|
5 |
+
"pruners": "None",
|
6 |
+
"pruning_ratio": 0.0,
|
7 |
+
"factorizers": "None",
|
8 |
+
"quantizers": "['llm-int8']",
|
9 |
+
"weight_quantization_bits": 4,
|
10 |
+
"output_deviation": 0.005,
|
11 |
+
"compilers": "None",
|
12 |
+
"static_batch": true,
|
13 |
+
"static_shape": true,
|
14 |
+
"controlnet": "None",
|
15 |
+
"unet_dim": 4,
|
16 |
+
"device": "cuda",
|
17 |
+
"cache_dir": "/ceph/hdd/staff/charpent/.cache/modelsjayrsbxc",
|
18 |
+
"batch_size": 1,
|
19 |
+
"model_name": "numind/NuExtract-large",
|
20 |
+
"task": "text_text_generation",
|
21 |
+
"max_batch_size": 1,
|
22 |
+
"qtype_weight": "torch.qint8",
|
23 |
+
"qtype_activation": "torch.quint8",
|
24 |
+
"qobserver": "<class 'torch.ao.quantization.observer.MinMaxObserver'>",
|
25 |
+
"qscheme": "torch.per_tensor_symmetric",
|
26 |
+
"qconfig": "x86",
|
27 |
+
"group_size": 128,
|
28 |
+
"damp_percent": 0.1,
|
29 |
+
"save_load_fn": "bitsandbytes"
|
30 |
+
}
|
31 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<|endoftext|>",
|
3 |
+
"eos_token": "<|endoftext|>",
|
4 |
+
"pad_token": "<|endoftext|>"
|
5 |
+
}
|
tokenization_phi3_small.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/tokenization_qwen.py
|
2 |
+
import os
|
3 |
+
from typing import Collection, List, Optional, Dict, Set, Tuple, Union
|
4 |
+
|
5 |
+
from functools import cached_property
|
6 |
+
|
7 |
+
import base64
|
8 |
+
|
9 |
+
from transformers import PreTrainedTokenizer, AddedToken, AutoConfig
|
10 |
+
from transformers.models.auto.tokenization_auto import get_tokenizer_config
|
11 |
+
import tiktoken
|
12 |
+
|
13 |
+
|
14 |
+
"""
|
15 |
+
This tokenizer is almost identical to tiktoken.get_encoding("cl100k_base")
|
16 |
+
with a few additional special tokens to support the ChatML format.
|
17 |
+
|
18 |
+
TODO(bapatra): Right now, I do not save the special tokens to the vocab file.
|
19 |
+
Maybe in the future, that would be useful? Can add that support later.
|
20 |
+
|
21 |
+
"""
|
22 |
+
|
23 |
+
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
|
24 |
+
with open(tiktoken_bpe_file, "rb") as f:
|
25 |
+
contents = f.read()
|
26 |
+
return {
|
27 |
+
base64.b64decode(token): int(rank)
|
28 |
+
for token, rank in (line.split() for line in contents.splitlines() if line)
|
29 |
+
}
|
30 |
+
|
31 |
+
# On the megatron codebase, we pad vocabularies to ensure matrix multiplication is fast.
|
32 |
+
# this in turn causes some indices to be empty. We account for these empty indices by adding
|
33 |
+
# dummy tokens to the tokenizer.
|
34 |
+
|
35 |
+
EFFECTIVE_PADDED_VOCAB_SIZE = 100352
|
36 |
+
ACTUAL_VOCAB_SIZE = 100276
|
37 |
+
|
38 |
+
|
39 |
+
DUMMY_TOKENS = {
|
40 |
+
f"<|dummy_id_{11 + offset}|>": 100276 + offset
|
41 |
+
for offset in range(1, EFFECTIVE_PADDED_VOCAB_SIZE - ACTUAL_VOCAB_SIZE)
|
42 |
+
}
|
43 |
+
|
44 |
+
SPECIAL_TOKENS = {
|
45 |
+
# tiktoken.get_encoding("cl100k_base")._special_tokens
|
46 |
+
'<|endoftext|>': 100257,
|
47 |
+
'<|fim_prefix|>': 100258,
|
48 |
+
'<|fim_middle|>': 100259,
|
49 |
+
'<|fim_suffix|>': 100260,
|
50 |
+
# Special tokens for post-training
|
51 |
+
"<|system|>": 100261,
|
52 |
+
"<|user|>": 100262,
|
53 |
+
"<|assistant|>": 100263,
|
54 |
+
# Dummy unused tokens
|
55 |
+
"<|dummy_id_0|>": 100264,
|
56 |
+
"<|dummy_id_1|>": 100265,
|
57 |
+
# Special tokens for post-training continued
|
58 |
+
"<|end|>": 100266,
|
59 |
+
# Some dummy tokens, so that tokenization is contiguous and does not cause issues
|
60 |
+
# Note that the 100256th token of tiktoken.get_encoding("cl100k_base") does not
|
61 |
+
# actually map to anything. So we use a dummy token here.
|
62 |
+
"<|dummy_id_2|>": 100256,
|
63 |
+
# Likewise, tokens from 100267 to 100275 are also unused
|
64 |
+
"<|dummy_id_3|>": 100267,
|
65 |
+
"<|dummy_id_4|>": 100268,
|
66 |
+
"<|dummy_id_5|>": 100269,
|
67 |
+
"<|dummy_id_6|>": 100270,
|
68 |
+
"<|dummy_id_7|>": 100271,
|
69 |
+
"<|dummy_id_8|>": 100272,
|
70 |
+
"<|dummy_id_9|>": 100273,
|
71 |
+
"<|dummy_id_10|>": 100274,
|
72 |
+
"<|dummy_id_11|>": 100275,
|
73 |
+
# The final end of prompt token
|
74 |
+
# (unused, but present as a part of tiktoken.get_encoding("cl100k_base")._special_tokens)
|
75 |
+
'<|endofprompt|>': 100276,
|
76 |
+
# Dummy tokens to account for padding of the tokenizer
|
77 |
+
# We pad to ensure tensor cores are used for vocab multiplication
|
78 |
+
**DUMMY_TOKENS
|
79 |
+
}
|
80 |
+
|
81 |
+
class Phi3SmallTokenizer(PreTrainedTokenizer):
|
82 |
+
vocab_files_names = {
|
83 |
+
"vocab_file": "cl100k_base.tiktoken"
|
84 |
+
}
|
85 |
+
|
86 |
+
model_input_names: List[str] = ["input_ids", "attention_mask"]
|
87 |
+
padding_side = "left"
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
vocab_file: Optional[str] = None,
|
92 |
+
errors: str = "replace",
|
93 |
+
**kwargs
|
94 |
+
) -> None:
|
95 |
+
# PreTrainedTokenizer's init calls _add_tokens, which in turn checks
|
96 |
+
# if the token is present in `self.special_tokens``. Hence instantiating it here.
|
97 |
+
# The way Qwen gets around this is by checking against SPECIAL_TOKENS
|
98 |
+
# But I think it's better to check against the objects own `special_tokens`
|
99 |
+
# in case we eventually want to allow the tokenizer to have special tokens.
|
100 |
+
self.special_tokens = SPECIAL_TOKENS
|
101 |
+
|
102 |
+
super().__init__(**kwargs)
|
103 |
+
self.errors = errors
|
104 |
+
|
105 |
+
base = tiktoken.get_encoding("cl100k_base")
|
106 |
+
if vocab_file is None:
|
107 |
+
self.mergeable_ranks: Dict[bytes, int] = base._mergeable_ranks
|
108 |
+
else:
|
109 |
+
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
|
110 |
+
|
111 |
+
self.pat_str = base._pat_str
|
112 |
+
|
113 |
+
enc = tiktoken.Encoding(
|
114 |
+
name="phi3small",
|
115 |
+
pat_str=self.pat_str,
|
116 |
+
mergeable_ranks=self.mergeable_ranks,
|
117 |
+
special_tokens=self.special_tokens,
|
118 |
+
)
|
119 |
+
self.tokenizer = enc
|
120 |
+
|
121 |
+
self.decoder: Dict[int, bytes] = {
|
122 |
+
v: k for k, v in self.mergeable_ranks.items()
|
123 |
+
}
|
124 |
+
self.decoder.update({v: k for k, v in self.special_tokens.items()})
|
125 |
+
|
126 |
+
self.eod_id = self.tokenizer.eot_token
|
127 |
+
self._eos_token = self._convert_id_to_token(self.eod_id)
|
128 |
+
|
129 |
+
# Setting the bos_token to be the same as the eos_token
|
130 |
+
# Note that this is **not** the correct thing to do, and is done
|
131 |
+
# just so that some of the downstream libraries do not break.
|
132 |
+
self._bos_token = self._eos_token
|
133 |
+
|
134 |
+
# Assign the special tokens to class variables
|
135 |
+
self.system_id = self.special_tokens["<|system|>"]
|
136 |
+
self.user_id = self.special_tokens["<|user|>"]
|
137 |
+
self.assistant_id = self.special_tokens["<|assistant|>"]
|
138 |
+
self.end_id = self.special_tokens["<|end|>"]
|
139 |
+
|
140 |
+
@cached_property
|
141 |
+
def dummy_token_indices(self) -> List[int]:
|
142 |
+
# There are some additional special tokens in the cl100k_base tokenizer
|
143 |
+
# that we do not use. Hence, we also consider them to be dummy tokens.
|
144 |
+
additional_tokens = [
|
145 |
+
"<|fim_prefix|>",
|
146 |
+
"<|fim_middle|>",
|
147 |
+
"<|fim_suffix|>",
|
148 |
+
"<|endofprompt|>"
|
149 |
+
]
|
150 |
+
dummy_token_indices = [index for token, index in self.special_tokens.items() if "dummy_id" in token]
|
151 |
+
dummy_token_indices.extend([self.special_tokens[token] for token in additional_tokens])
|
152 |
+
return sorted(dummy_token_indices)
|
153 |
+
|
154 |
+
def __getstate__(self):
|
155 |
+
state = self.__dict__.copy()
|
156 |
+
del state["tokenizer"]
|
157 |
+
return state
|
158 |
+
|
159 |
+
def __setstate__(self, state):
|
160 |
+
self.__dict__ = state
|
161 |
+
enc = tiktoken.Encoding(
|
162 |
+
name="cl100k_im",
|
163 |
+
pat_str=self.pat_str,
|
164 |
+
mergeable_ranks=self.mergeable_ranks,
|
165 |
+
special_tokens=self.special_tokens,
|
166 |
+
)
|
167 |
+
self.tokenizer = enc
|
168 |
+
|
169 |
+
def __len__(self):
|
170 |
+
return self.tokenizer.n_vocab
|
171 |
+
|
172 |
+
@classmethod
|
173 |
+
def from_pretrained(
|
174 |
+
cls,
|
175 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
176 |
+
*init_inputs,
|
177 |
+
**kwargs,
|
178 |
+
):
|
179 |
+
cls_kwargs = kwargs
|
180 |
+
# First try to load from the tokenization config if it exists
|
181 |
+
tokenization_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
182 |
+
if tokenization_config:
|
183 |
+
cls_kwargs = {
|
184 |
+
**tokenization_config,
|
185 |
+
**cls_kwargs
|
186 |
+
}
|
187 |
+
else:
|
188 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
189 |
+
cls_kwargs["model_max_length"] = config.max_position_embeddings
|
190 |
+
return cls(**cls_kwargs)
|
191 |
+
|
192 |
+
def get_vocab(self) -> Dict[Union[str, bytes], int]:
|
193 |
+
return {**self.mergeable_ranks, **self.special_tokens}
|
194 |
+
|
195 |
+
def convert_tokens_to_ids(
|
196 |
+
self,
|
197 |
+
tokens: Union[bytes, str, List[Union[bytes, str]]]
|
198 |
+
) -> Union[int, List[int]]:
|
199 |
+
ids = []
|
200 |
+
if isinstance(tokens, (str, bytes)):
|
201 |
+
if tokens in self.special_tokens:
|
202 |
+
return self.special_tokens[tokens]
|
203 |
+
else:
|
204 |
+
return self.mergeable_ranks.get(tokens)
|
205 |
+
ids: List[int] = []
|
206 |
+
for token in tokens:
|
207 |
+
ids.append(self.convert_tokens_to_ids(token))
|
208 |
+
return ids
|
209 |
+
|
210 |
+
def _add_tokens(
|
211 |
+
self,
|
212 |
+
new_tokens: Union[List[str], List[AddedToken]],
|
213 |
+
special_tokens: bool = False,
|
214 |
+
) -> int:
|
215 |
+
if not special_tokens and new_tokens:
|
216 |
+
raise ValueError("Only special tokens can be added to this tokenizer")
|
217 |
+
for token in new_tokens:
|
218 |
+
surface_form = token.content if isinstance(token, AddedToken) else token
|
219 |
+
if surface_form not in self.special_tokens:
|
220 |
+
raise ValueError(
|
221 |
+
"For now, we do not support unknown special tokens\n"
|
222 |
+
"In the future, if there is a need for this, we can add special tokens to the tokenizer\n"
|
223 |
+
"starting from rank 100261 - 100263 and then 100266 - 100275.\n"
|
224 |
+
"And finally, we can re-construct the enc object back\n"
|
225 |
+
)
|
226 |
+
return 0
|
227 |
+
|
228 |
+
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
|
229 |
+
file_path = os.path.join(save_directory, "cl100k_base.tiktoken")
|
230 |
+
with open(file_path, "w") as f:
|
231 |
+
for token, rank in self.mergeable_ranks.items():
|
232 |
+
line = base64.b64encode(token).decode("utf-8") + " " + str(rank) + "\n"
|
233 |
+
f.write(line)
|
234 |
+
return (file_path,)
|
235 |
+
|
236 |
+
def tokenize(
|
237 |
+
self,
|
238 |
+
text: str,
|
239 |
+
allowed_special: Union[Set, str] = "all",
|
240 |
+
disallowed_special: Union[Collection, str] = (),
|
241 |
+
**kwargs
|
242 |
+
) -> List[Union[bytes, str]]:
|
243 |
+
tokens: List[Union[bytes, str]] = []
|
244 |
+
for token_id in self.tokenizer.encode(
|
245 |
+
text, allowed_special=allowed_special, disallowed_special=disallowed_special
|
246 |
+
):
|
247 |
+
tokens.append(self.decoder[token_id])
|
248 |
+
return tokens
|
249 |
+
|
250 |
+
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
|
251 |
+
"""
|
252 |
+
Converts a sequence of tokens in a single string.
|
253 |
+
"""
|
254 |
+
text = ""
|
255 |
+
temp = b""
|
256 |
+
for t in tokens:
|
257 |
+
if isinstance(t, str):
|
258 |
+
if temp:
|
259 |
+
text += temp.decode("utf-8", errors=self.errors)
|
260 |
+
temp = b""
|
261 |
+
text += t
|
262 |
+
elif isinstance(t, bytes):
|
263 |
+
temp += t
|
264 |
+
else:
|
265 |
+
raise TypeError("token should only be of type types or str")
|
266 |
+
if temp:
|
267 |
+
text += temp.decode("utf-8", errors=self.errors)
|
268 |
+
return text
|
269 |
+
|
270 |
+
@property
|
271 |
+
def vocab_size(self):
|
272 |
+
return self.tokenizer.n_vocab
|
273 |
+
|
274 |
+
@property
|
275 |
+
def eos_token_id(self) -> int:
|
276 |
+
return self.eod_id
|
277 |
+
|
278 |
+
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
|
279 |
+
"""Converts an id to a token, special tokens included"""
|
280 |
+
if index in self.decoder:
|
281 |
+
return self.decoder[index]
|
282 |
+
raise ValueError("unknown ids")
|
283 |
+
|
284 |
+
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
|
285 |
+
"""Converts a token to an id using the vocab, special tokens included"""
|
286 |
+
if token in self.special_tokens:
|
287 |
+
return self.special_tokens[token]
|
288 |
+
if token in self.mergeable_ranks:
|
289 |
+
return self.mergeable_ranks[token]
|
290 |
+
raise ValueError("unknown token")
|
291 |
+
|
292 |
+
def _tokenize(self, text: str, **kwargs):
|
293 |
+
"""
|
294 |
+
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
295 |
+
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
|
296 |
+
Do NOT take care of added tokens.
|
297 |
+
"""
|
298 |
+
raise NotImplementedError
|
299 |
+
|
300 |
+
def _decode(
|
301 |
+
self,
|
302 |
+
token_ids: Union[int, List[int]],
|
303 |
+
skip_special_tokens: bool = False,
|
304 |
+
errors: str = None,
|
305 |
+
**kwargs,
|
306 |
+
) -> str:
|
307 |
+
if isinstance(token_ids, int):
|
308 |
+
token_ids = [token_ids]
|
309 |
+
if skip_special_tokens:
|
310 |
+
token_ids = [i for i in token_ids if i < self.eod_id]
|
311 |
+
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
|
312 |
+
|
313 |
+
|
tokenizer_config.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_commit_hash": "fc8e001871f4a6be8e6079093b33de334a2316c9",
|
3 |
+
"_from_auto": true,
|
4 |
+
"added_tokens_decoder": {},
|
5 |
+
"auto_map": {
|
6 |
+
"AutoTokenizer": [
|
7 |
+
"tokenization_phi3_small.Phi3SmallTokenizer",
|
8 |
+
null
|
9 |
+
]
|
10 |
+
},
|
11 |
+
"bos_token": "<|endoftext|>",
|
12 |
+
"chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
|
13 |
+
"clean_up_tokenization_spaces": true,
|
14 |
+
"eos_token": "<|endoftext|>",
|
15 |
+
"legacy": false,
|
16 |
+
"model_max_length": 8192,
|
17 |
+
"pad_token": "<|endoftext|>",
|
18 |
+
"tokenizer_class": "Phi3SmallTokenizer",
|
19 |
+
"trust_remote_code": true
|
20 |
+
}
|
triton_blocksparse_attention_layer.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Tuple, TypeVar
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch
|
5 |
+
import triton
|
6 |
+
|
7 |
+
from functools import lru_cache
|
8 |
+
|
9 |
+
|
10 |
+
from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd
|
11 |
+
|
12 |
+
|
13 |
+
Layout = Tuple[torch.LongTensor, torch.LongTensor]
|
14 |
+
|
15 |
+
|
16 |
+
def create_sparse_attn_mask(
|
17 |
+
n_heads: int,
|
18 |
+
max_seq_len: int,
|
19 |
+
max_seq_len_k: int,
|
20 |
+
dtype: torch.dtype,
|
21 |
+
device: torch.device,
|
22 |
+
BLOCK: int,
|
23 |
+
local_blocks: int,
|
24 |
+
vert_stride: int,
|
25 |
+
homo_head: bool,
|
26 |
+
return_dense: bool
|
27 |
+
) -> Tuple[Layout, torch.Tensor, Optional[torch.Tensor]]:
|
28 |
+
layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
|
29 |
+
n_heads=n_heads,
|
30 |
+
q_len=max_seq_len,
|
31 |
+
N_CTX=max_seq_len_k,
|
32 |
+
dtype=dtype,
|
33 |
+
device=device,
|
34 |
+
BLOCK=BLOCK,
|
35 |
+
local_blocks=local_blocks,
|
36 |
+
vert_stride=vert_stride,
|
37 |
+
homo_head=homo_head,
|
38 |
+
return_dense=return_dense
|
39 |
+
)
|
40 |
+
return layout, block_sparse_pattern
|
41 |
+
|
42 |
+
|
43 |
+
class BlockSparseAttentionLayer(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
n_heads: int,
|
47 |
+
max_seq_len: int,
|
48 |
+
sparse_block_size: int,
|
49 |
+
local_blocks: int,
|
50 |
+
vert_stride: int,
|
51 |
+
kernel_block_size: Optional[int] = None,
|
52 |
+
homo_head: bool = False,
|
53 |
+
active_head_range: Optional[Tuple[int]] = None
|
54 |
+
) -> None:
|
55 |
+
super().__init__()
|
56 |
+
|
57 |
+
self.n_heads = n_heads
|
58 |
+
self.max_seq_len = max_seq_len
|
59 |
+
self.sparse_block_size = sparse_block_size
|
60 |
+
self.kernel_block_size = kernel_block_size or sparse_block_size
|
61 |
+
self.local_blocks = local_blocks
|
62 |
+
self.vert_stride = vert_stride
|
63 |
+
self.homo_head = homo_head
|
64 |
+
self.active_head_range = active_head_range
|
65 |
+
|
66 |
+
# Internal Parameters used by the layer
|
67 |
+
self._sparse_block_mask = None
|
68 |
+
self._sparse_layout = None
|
69 |
+
self._dtype = None
|
70 |
+
self._device = None
|
71 |
+
|
72 |
+
# TODO(bapatra): Ideally, I'd want to keep all the code for
|
73 |
+
# forward to be handled here, and not branch for training and inference.
|
74 |
+
# However, that refactor would need a lot of testing. For now, using the
|
75 |
+
# training op as is, and will refactor again later.
|
76 |
+
|
77 |
+
def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
|
78 |
+
self._sparse_block_mask = self._sparse_block_mask[h_start: h_end]
|
79 |
+
self._sparse_layout[0] = self._sparse_layout[0][h_start: h_end]
|
80 |
+
self._sparse_layout[1] = self._sparse_layout[1][h_start: h_end]
|
81 |
+
|
82 |
+
def _initialize_internals(
|
83 |
+
self,
|
84 |
+
dtype: torch.dtype,
|
85 |
+
device: torch.device
|
86 |
+
) -> None:
|
87 |
+
self._dtype, self._device = dtype, device
|
88 |
+
self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
|
89 |
+
n_heads=self.n_heads,
|
90 |
+
max_seq_len=self.max_seq_len,
|
91 |
+
max_seq_len_k=self.max_seq_len,
|
92 |
+
dtype=dtype,
|
93 |
+
device=device,
|
94 |
+
BLOCK=self.sparse_block_size,
|
95 |
+
local_blocks=self.local_blocks,
|
96 |
+
vert_stride=self.vert_stride,
|
97 |
+
homo_head=self.homo_head,
|
98 |
+
return_dense=False,
|
99 |
+
)
|
100 |
+
if (not self.homo_head) and (self.active_head_range is not None):
|
101 |
+
assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
|
102 |
+
h_start, h_end = self.active_head_range
|
103 |
+
self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)
|
104 |
+
|
105 |
+
assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
|
106 |
+
assert self.kernel_block_size >=16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"
|
107 |
+
if self.sparse_block_size // self.kernel_block_size > 1:
|
108 |
+
_mul = self.sparse_block_size // self.kernel_block_size
|
109 |
+
# need to consider if block_m and block_n are different
|
110 |
+
self._sparse_block_mask = torch.kron(self._sparse_block_mask, self._sparse_block_mask.new_ones(_mul, _mul))
|
111 |
+
num_sparse_blocks = self._sparse_block_mask.size(-1)
|
112 |
+
block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
|
113 |
+
self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)
|
114 |
+
|
115 |
+
|
116 |
+
def forward(
|
117 |
+
self,
|
118 |
+
q: torch.Tensor,
|
119 |
+
k: torch.Tensor,
|
120 |
+
v: torch.Tensor,
|
121 |
+
sm_scale: float,
|
122 |
+
*,
|
123 |
+
# Arguments Related to Block Attention Inference
|
124 |
+
left_paddings: Optional[torch.LongTensor] = None,
|
125 |
+
seqlens: Optional[torch.LongTensor] = None,
|
126 |
+
# Arguements Related to Variable Length Inference
|
127 |
+
cu_seqlens_k: Optional[torch.LongTensor] = None,
|
128 |
+
cu_seqlens_q: Optional[torch.LongTensor] = None,
|
129 |
+
) -> torch.Tensor:
|
130 |
+
|
131 |
+
if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
|
132 |
+
blocksparse_op = get_local_strided_sparse_attention_op(
|
133 |
+
n_heads=self.n_heads,
|
134 |
+
max_seq_len=self.max_seq_len,
|
135 |
+
sparse_block_size=self.sparse_block_size,
|
136 |
+
kernel_block_size=self.kernel_block_size,
|
137 |
+
local_blocks=self.local_blocks,
|
138 |
+
vert_stride=self.vert_stride,
|
139 |
+
homo_head=self.homo_head,
|
140 |
+
device=q.device,
|
141 |
+
inference=not self.training
|
142 |
+
)
|
143 |
+
return blocksparse_op(q, k, v, sm_scale)
|
144 |
+
|
145 |
+
assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"
|
146 |
+
# First set internals if they have not been set
|
147 |
+
if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
|
148 |
+
self._initialize_internals(dtype=q.dtype, device=q.device)
|
149 |
+
|
150 |
+
if k.dim() == 3:
|
151 |
+
assert cu_seqlens_k is not None
|
152 |
+
return blocksparse_flash_attn_varlen_fwd(
|
153 |
+
q=q,
|
154 |
+
k=k,
|
155 |
+
v=v,
|
156 |
+
cu_seqlens_k=cu_seqlens_k,
|
157 |
+
cu_seqlens_q=cu_seqlens_q,
|
158 |
+
sm_scale=sm_scale,
|
159 |
+
sparse_layout=self._sparse_layout,
|
160 |
+
block_size=self.kernel_block_size,
|
161 |
+
max_seqlen=self.max_seq_len,
|
162 |
+
)
|
163 |
+
if k.dim() == 4:
|
164 |
+
assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
|
165 |
+
return blocksparse_flash_attn_padded_fwd(
|
166 |
+
q=q,
|
167 |
+
k=k,
|
168 |
+
v=v,
|
169 |
+
sm_scale=sm_scale,
|
170 |
+
sparse_layout=self._sparse_layout,
|
171 |
+
left_paddings=left_paddings,
|
172 |
+
seqlens=seqlens,
|
173 |
+
block_size=self.kernel_block_size,
|
174 |
+
max_seqlen=self.max_seq_len,
|
175 |
+
)
|
176 |
+
raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')
|
triton_flash_blocksparse_attn.py
ADDED
@@ -0,0 +1,1947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Author: Eric Lin (xihlin)
|
3 |
+
"""
|
4 |
+
"""
|
5 |
+
... note(bapatra)::
|
6 |
+
This is written as one big file, instead of splitting into logical components because I was running into issues with transformers auto module
|
7 |
+
imports when splitting into different files. I've tried keeping the logical partitions demarkated with comment blocks, but it is not ideal.
|
8 |
+
In the future, would be really good to revisit this and refactor into a more readable file structure.
|
9 |
+
|
10 |
+
"""
|
11 |
+
from typing import TypeVar
|
12 |
+
from functools import lru_cache
|
13 |
+
import math
|
14 |
+
import pytest
|
15 |
+
import torch
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
import triton
|
19 |
+
import triton.language as tl
|
20 |
+
|
21 |
+
import os
|
22 |
+
|
23 |
+
import dataclasses
|
24 |
+
|
25 |
+
Phi3SmallConfig = TypeVar('Phi3SmallConfig')
|
26 |
+
|
27 |
+
# triton 2.0.0: fail at backward on A100, for the examples, if h_dim=128.
|
28 |
+
|
29 |
+
# Done
|
30 |
+
# 1. strided of qkv
|
31 |
+
# 2. seq len not power of 2
|
32 |
+
# 3. bf16 with Triton May, 2023
|
33 |
+
|
34 |
+
# TODO:
|
35 |
+
# 1. wip: support non-contiguous backward, also help reduce memory allocation in training (q, k, v split)
|
36 |
+
# 2. block sparse with different BLOCK_M, BLOCK_N?
|
37 |
+
# 3. for Lq not divided by BLOCK_M, BLOCK_N, only apply mask to K/V on last batch, still need to apply mask on Q.
|
38 |
+
# Attempt, fail to compile
|
39 |
+
# 4. For 2nd iter of inference, BLOCK_M=1, how to make things work? K/V maynot divided by BLOCK_N.
|
40 |
+
# 5. The inner loop can also be paralled via bigger num_stage(better) or on different thread-block (via m/L and atomic update, but this no-comm/sync between blocks)
|
41 |
+
|
42 |
+
|
43 |
+
###########################################################
|
44 |
+
################### Kernel Parameters #####################
|
45 |
+
###########################################################
|
46 |
+
|
47 |
+
@dataclasses.dataclass
|
48 |
+
class BlockSparseParams(object):
|
49 |
+
block_size: int
|
50 |
+
kernel_block_size: int
|
51 |
+
num_local_blocks: int
|
52 |
+
vert_stride: int
|
53 |
+
homo_head_pattern: bool = False
|
54 |
+
|
55 |
+
@classmethod
|
56 |
+
def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
|
57 |
+
return cls(
|
58 |
+
block_size=config.blocksparse_block_size,
|
59 |
+
kernel_block_size=config.blocksparse_triton_kernel_block_size,
|
60 |
+
num_local_blocks=config.blocksparse_num_local_blocks,
|
61 |
+
vert_stride=config.blocksparse_vert_stride,
|
62 |
+
homo_head_pattern=config.blocksparse_homo_head_pattern,
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
###########################################################
|
67 |
+
###########################################################
|
68 |
+
|
69 |
+
###########################################################
|
70 |
+
################### Utility Functions #####################
|
71 |
+
###########################################################
|
72 |
+
|
73 |
+
# helper functions for 3D sparse pattern
|
74 |
+
# these function are not optimized and very inefficient. Avoid calling them too frequent.
|
75 |
+
# currently, it is only called within `get_local_strided_sparse_attention_op`, which is cached.
|
76 |
+
def dense_to_crow_col(x):
|
77 |
+
''' Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
|
78 |
+
param:
|
79 |
+
TODO:
|
80 |
+
1. improve efficiency, is it faster if done in CPU, or customize a cuda kernel for it?
|
81 |
+
NOTE: col_indices padded -1
|
82 |
+
'''
|
83 |
+
pad = -1
|
84 |
+
dim = x.dim()
|
85 |
+
assert x.dim() in (2, 3)
|
86 |
+
if x.dim() == 2:
|
87 |
+
x = x[None]
|
88 |
+
x = [xi.to_sparse_csr() for xi in x]
|
89 |
+
crows = torch.vstack([xi.crow_indices() for xi in x])
|
90 |
+
cols = [xi.col_indices() for xi in x]
|
91 |
+
max_cols = max(len(xi) for xi in cols)
|
92 |
+
cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols]
|
93 |
+
cols = torch.vstack(cols)
|
94 |
+
if dim == 2:
|
95 |
+
crows = crows[0]
|
96 |
+
cols = cols[0]
|
97 |
+
return crows, cols
|
98 |
+
|
99 |
+
|
100 |
+
def crow_col_to_dense(crows, cols, dtype=torch.float16):
|
101 |
+
dim = crows.dim()
|
102 |
+
if dim == 1:
|
103 |
+
crows = crows[None]
|
104 |
+
cols = cols[None]
|
105 |
+
device = crows.device
|
106 |
+
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
|
107 |
+
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
|
108 |
+
x = torch.zeros(shape, dtype=dtype)
|
109 |
+
for i in range(shape[0]):
|
110 |
+
for j in range(shape[1]):
|
111 |
+
x[i, j, cols[i, crows[i, j]:crows[i, j+1]]] = 1
|
112 |
+
if dim == 1:
|
113 |
+
x = x[0]
|
114 |
+
return x.to(device)
|
115 |
+
|
116 |
+
|
117 |
+
def dense_to_ccol_row(x):
|
118 |
+
'''Similar, but to CSC format
|
119 |
+
'''
|
120 |
+
x = x.transpose(-2, -1)
|
121 |
+
return dense_to_crow_col(x)
|
122 |
+
|
123 |
+
|
124 |
+
def ccol_row_to_dense(ccol, rows, dtype=torch.float16):
|
125 |
+
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
|
126 |
+
|
127 |
+
|
128 |
+
def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False):
|
129 |
+
'''
|
130 |
+
:return: a tuple of 3:
|
131 |
+
- tuple of crow_indices, col_indices representation of CSR format.
|
132 |
+
- block dense mask
|
133 |
+
- all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
|
134 |
+
'''
|
135 |
+
with torch.no_grad():
|
136 |
+
N_BLOCK = triton.cdiv(N_CTX, BLOCK)
|
137 |
+
q_pos = torch.arange(N_BLOCK)[:, None]
|
138 |
+
k_pos = torch.arange(N_BLOCK)[None]
|
139 |
+
mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0
|
140 |
+
block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
|
141 |
+
N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
|
142 |
+
block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr()
|
143 |
+
if return_dense:
|
144 |
+
mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
|
145 |
+
causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
|
146 |
+
mask_dense = mask_dense[-q_len:, :N_CTX] * causal_mask
|
147 |
+
return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense
|
148 |
+
else:
|
149 |
+
return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None
|
150 |
+
|
151 |
+
|
152 |
+
def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False):
|
153 |
+
'''
|
154 |
+
:return: a tuple of 3:
|
155 |
+
- tuple of crow_indices, col_indices representation of CSR format.
|
156 |
+
- block dense mask
|
157 |
+
- all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
|
158 |
+
'''
|
159 |
+
if homo_head:
|
160 |
+
with torch.no_grad():
|
161 |
+
(crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense)
|
162 |
+
crow = crow[None].expand(n_heads, crow.shape[0])
|
163 |
+
col = col[None].expand(n_heads, col.shape[0])
|
164 |
+
if return_dense:
|
165 |
+
mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape)
|
166 |
+
return (crow, col), block_mask_dense, mask_dense
|
167 |
+
|
168 |
+
with torch.no_grad():
|
169 |
+
N_BLOCK = triton.cdiv(N_CTX, BLOCK)
|
170 |
+
q_pos = torch.arange(N_BLOCK)[None, :, None]
|
171 |
+
k_pos = torch.arange(N_BLOCK)[None, None]
|
172 |
+
head_sliding_step = max(1, int(vert_stride / n_heads)) # if vert_stride <= n_heads, rotating the heads
|
173 |
+
mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)]
|
174 |
+
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
|
175 |
+
block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
|
176 |
+
N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
|
177 |
+
block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:]
|
178 |
+
if return_dense:
|
179 |
+
mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
|
180 |
+
causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
|
181 |
+
mask_dense = mask_dense[..., -q_len:, :N_CTX] * causal_mask[None]
|
182 |
+
return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense
|
183 |
+
else:
|
184 |
+
return dense_to_crow_col(block_mask_dense_output), block_mask_dense, None
|
185 |
+
|
186 |
+
|
187 |
+
def get_sparse_attn_mask(q, N_CTX, *args, **kwargs):
|
188 |
+
return _get_sparse_attn_mask(q.size(1), q.size(2), N_CTX, q.dtype, q.device, *args, **kwargs)
|
189 |
+
|
190 |
+
###########################################################
|
191 |
+
###########################################################
|
192 |
+
|
193 |
+
###########################################################
|
194 |
+
###################### Training Kernels ###################
|
195 |
+
###########################################################
|
196 |
+
|
197 |
+
# TODO: only apply loading/saving mask on the last iteration for EVEN_N_BLOCK, useful for 1st iteration of inference.
|
198 |
+
# Experiment failed inside loop.
|
199 |
+
# Another idea: only on saving? load even out of boundary(will it causes illegal access error)?
|
200 |
+
@triton.jit
|
201 |
+
def _fwd_kernel(
|
202 |
+
Q, K, V, sm_scale,
|
203 |
+
layout_crow_ptr,
|
204 |
+
layout_col_ptr,
|
205 |
+
layout_crow_stride_h, layout_crow_stride_m,
|
206 |
+
layout_col_stride_h, layout_col_stride_m,
|
207 |
+
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug. TMP, L, M are assumed to have contiguous layouts
|
208 |
+
Out,
|
209 |
+
stride_qz, stride_qh, stride_qm, stride_qd,
|
210 |
+
stride_kz, stride_kh, stride_kn, stride_kd,
|
211 |
+
stride_vz, stride_vh, stride_vn, stride_vd,
|
212 |
+
stride_oz, stride_oh, stride_om, stride_od,
|
213 |
+
Z, H, N_CTX,
|
214 |
+
PAST_LEN,
|
215 |
+
Q_ROUNDED_LEN,
|
216 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
217 |
+
BLOCK_N: tl.constexpr,
|
218 |
+
EVEN_M_BLOCK: tl.constexpr,
|
219 |
+
EVEN_N_BLOCK: tl.constexpr,
|
220 |
+
INFERENCE: tl.constexpr,
|
221 |
+
NUM_DBLOCKS: tl.constexpr,
|
222 |
+
):
|
223 |
+
Q_LEN = N_CTX - PAST_LEN
|
224 |
+
start_m = tl.program_id(0)
|
225 |
+
off_hz = tl.program_id(1)
|
226 |
+
off_h = off_hz % H
|
227 |
+
off_z = off_hz // H
|
228 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
229 |
+
K += off_z * stride_kz + off_h * stride_kh
|
230 |
+
V += off_z * stride_vz + off_h * stride_vh
|
231 |
+
# initialize offsets
|
232 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
233 |
+
offs_n = tl.arange(0, BLOCK_N)
|
234 |
+
offs_d = tl.arange(0, BLOCK_DMODEL)
|
235 |
+
off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
|
236 |
+
# off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
|
237 |
+
off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
|
238 |
+
off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
|
239 |
+
# Initialize pointers to Q, K, V
|
240 |
+
q_ptrs = Q + off_q
|
241 |
+
k_ptrs = K + off_k
|
242 |
+
v_ptrs = V + off_v
|
243 |
+
# initialize pointer to m and l
|
244 |
+
t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m
|
245 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
|
246 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
247 |
+
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
248 |
+
if NUM_DBLOCKS >= 2:
|
249 |
+
acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
250 |
+
|
251 |
+
# load q: it will stay in SRAM throughout
|
252 |
+
if EVEN_M_BLOCK:
|
253 |
+
q = tl.load(q_ptrs)
|
254 |
+
if NUM_DBLOCKS >= 2:
|
255 |
+
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
|
256 |
+
else:
|
257 |
+
q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
|
258 |
+
if NUM_DBLOCKS >= 2:
|
259 |
+
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN)
|
260 |
+
|
261 |
+
layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m
|
262 |
+
start_l = tl.load(layout_ptr).to(tl.int32)
|
263 |
+
end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32)
|
264 |
+
|
265 |
+
# loop over k, v and update accumulator
|
266 |
+
for col_idx_idx in range(start_l, end_l):
|
267 |
+
col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32)
|
268 |
+
start_n = col_idx * BLOCK_N
|
269 |
+
# -- compute qk ----
|
270 |
+
if EVEN_N_BLOCK:
|
271 |
+
k = tl.load(k_ptrs + start_n * stride_kn)
|
272 |
+
else:
|
273 |
+
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX)
|
274 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
275 |
+
qk += tl.dot(q, k)
|
276 |
+
|
277 |
+
if NUM_DBLOCKS >= 2:
|
278 |
+
if EVEN_N_BLOCK:
|
279 |
+
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd)
|
280 |
+
else:
|
281 |
+
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX)
|
282 |
+
qk += tl.dot(q2, k)
|
283 |
+
|
284 |
+
qk *= sm_scale
|
285 |
+
qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf'))
|
286 |
+
# -- compute m_ij, p, l_ij
|
287 |
+
m_ij = tl.max(qk, 1)
|
288 |
+
p = tl.exp(qk - m_ij[:, None])
|
289 |
+
l_ij = tl.sum(p, 1)
|
290 |
+
# -- update m_i and l_i
|
291 |
+
m_i_new = tl.maximum(m_i, m_ij)
|
292 |
+
alpha = tl.exp(m_i - m_i_new)
|
293 |
+
beta = tl.exp(m_ij - m_i_new)
|
294 |
+
l_i_new = alpha * l_i + beta * l_ij
|
295 |
+
# -- update output accumulator --
|
296 |
+
# scale p
|
297 |
+
p_scale = beta / l_i_new
|
298 |
+
p = p * p_scale[:, None]
|
299 |
+
# scale acc
|
300 |
+
acc_scale = l_i / l_i_new * alpha
|
301 |
+
# tl.store(t_ptrs, acc_scale)
|
302 |
+
# acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
303 |
+
acc = acc * acc_scale[:, None]
|
304 |
+
if NUM_DBLOCKS >= 2:
|
305 |
+
acc2 = acc2 * acc_scale[:, None]
|
306 |
+
p = p.to(Q.dtype.element_ty)
|
307 |
+
# update acc
|
308 |
+
if EVEN_N_BLOCK:
|
309 |
+
v = tl.load(v_ptrs + start_n * stride_vn)
|
310 |
+
else:
|
311 |
+
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX)
|
312 |
+
acc += tl.dot(p, v)
|
313 |
+
|
314 |
+
if NUM_DBLOCKS >= 2:
|
315 |
+
if EVEN_N_BLOCK:
|
316 |
+
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd)
|
317 |
+
else:
|
318 |
+
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX)
|
319 |
+
acc2 += tl.dot(p, v)
|
320 |
+
|
321 |
+
# update m_i and l_i
|
322 |
+
l_i = l_i_new
|
323 |
+
m_i = m_i_new
|
324 |
+
|
325 |
+
# rematerialize offsets to save registers
|
326 |
+
# start_m = tl.program_id(0)
|
327 |
+
# offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
328 |
+
# write back l and m
|
329 |
+
if not INFERENCE:
|
330 |
+
l_ptrs = L + off_hz * N_CTX + offs_m
|
331 |
+
m_ptrs = M + off_hz * N_CTX + offs_m
|
332 |
+
if EVEN_M_BLOCK:
|
333 |
+
tl.store(l_ptrs, l_i)
|
334 |
+
tl.store(m_ptrs, m_i)
|
335 |
+
else:
|
336 |
+
tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN)
|
337 |
+
tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN)
|
338 |
+
# initialize pointers to output
|
339 |
+
# offs_n = tl.arange(0, BLOCK_DMODEL)
|
340 |
+
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
|
341 |
+
out_ptrs = Out + off_o
|
342 |
+
tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
|
343 |
+
if NUM_DBLOCKS >= 2:
|
344 |
+
tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN)
|
345 |
+
|
346 |
+
|
347 |
+
## backward
|
348 |
+
@triton.heuristics(
|
349 |
+
{
|
350 |
+
'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
|
351 |
+
}
|
352 |
+
)
|
353 |
+
@triton.jit
|
354 |
+
def _bwd_preprocess(
|
355 |
+
Out, DO, L, # assume contiguous for Out, DO, L, NewDO, Delta layout.
|
356 |
+
NewDO, Delta,
|
357 |
+
N_CTX,
|
358 |
+
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
359 |
+
EVEN_M_BLOCK: tl.constexpr,
|
360 |
+
):
|
361 |
+
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
362 |
+
off_d = tl.arange(0, D_HEAD)
|
363 |
+
# load
|
364 |
+
if EVEN_M_BLOCK:
|
365 |
+
o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
|
366 |
+
do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
|
367 |
+
else:
|
368 |
+
o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
|
369 |
+
do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
|
370 |
+
denom = tl.load(L + off_m).to(tl.float32)
|
371 |
+
# compute
|
372 |
+
do = do / denom[:, None]
|
373 |
+
delta = tl.sum(o * do, axis=1)
|
374 |
+
# write-back
|
375 |
+
if EVEN_M_BLOCK:
|
376 |
+
tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do)
|
377 |
+
else:
|
378 |
+
tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do, mask=off_m[:, None] < N_CTX)
|
379 |
+
tl.store(Delta + off_m, delta)
|
380 |
+
|
381 |
+
|
382 |
+
# Does not suuport unequal seqlen(q) and seqlen(k)
|
383 |
+
@triton.heuristics(
|
384 |
+
{
|
385 |
+
'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
|
386 |
+
'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0,
|
387 |
+
}
|
388 |
+
)
|
389 |
+
@triton.jit
|
390 |
+
def _bwd_kernel(
|
391 |
+
Q, K, V, sm_scale,
|
392 |
+
layout_ccol_ptr,
|
393 |
+
layout_row_ptr,
|
394 |
+
layout_ccol_stride_h, layout_ccol_stride_m,
|
395 |
+
layout_row_stride_h, layout_row_stride_m,
|
396 |
+
Out, DO, # assume contigous: Out, Do, DQ, DK, DV, L, M, D, seq(q) == seq(k), with stride_oz, stride_oh, stride_om, stride_od,
|
397 |
+
DQ, DK, DV,
|
398 |
+
L, M,
|
399 |
+
D,
|
400 |
+
stride_qz, stride_qh, stride_qm, stride_qd,
|
401 |
+
stride_kz, stride_kh, stride_kn, stride_kd,
|
402 |
+
stride_vz, stride_vh, stride_vn, stride_vd,
|
403 |
+
stride_oz, stride_oh, stride_om, stride_od,
|
404 |
+
# stride_dz, stride_dh, stride_dm, stride_dd,
|
405 |
+
Z, H, N_CTX,
|
406 |
+
num_block,
|
407 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
408 |
+
BLOCK_N: tl.constexpr,
|
409 |
+
EVEN_M_BLOCK: tl.constexpr,
|
410 |
+
EVEN_N_BLOCK: tl.constexpr,
|
411 |
+
NUM_DBLOCKS: tl.constexpr,
|
412 |
+
):
|
413 |
+
start_n = tl.program_id(0)
|
414 |
+
off_hz = tl.program_id(1)
|
415 |
+
off_z = off_hz // H
|
416 |
+
off_h = off_hz % H
|
417 |
+
# offset pointers for batch/head
|
418 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
419 |
+
K += off_z * stride_kz + off_h * stride_kh
|
420 |
+
V += off_z * stride_vz + off_h * stride_vh
|
421 |
+
DO += off_z * stride_oz + off_h * stride_oh
|
422 |
+
DQ += off_z * stride_oz + off_h * stride_oh
|
423 |
+
DK += off_z * stride_oz + off_h * stride_oh
|
424 |
+
DV += off_z * stride_oz + off_h * stride_oh
|
425 |
+
# Look like this loop can be parallelled
|
426 |
+
# for start_n in range(0, num_block):
|
427 |
+
|
428 |
+
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
429 |
+
offs_m = tl.arange(0, BLOCK_M)
|
430 |
+
offs_d = tl.arange(0, BLOCK_DMODEL)
|
431 |
+
# initialize pointers to value-like data
|
432 |
+
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
|
433 |
+
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)
|
434 |
+
|
435 |
+
# pointer to row-wise quantities in value-like data
|
436 |
+
D_ptrs = D + off_hz * N_CTX
|
437 |
+
m_ptrs = M + off_hz * N_CTX
|
438 |
+
# initialize dv amd dk
|
439 |
+
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
440 |
+
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
441 |
+
# k and v stay in SRAM throughout
|
442 |
+
if EVEN_N_BLOCK:
|
443 |
+
k = tl.load(k_ptrs)
|
444 |
+
v = tl.load(v_ptrs)
|
445 |
+
else:
|
446 |
+
k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX)
|
447 |
+
v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX)
|
448 |
+
|
449 |
+
if NUM_DBLOCKS >= 2:
|
450 |
+
dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
451 |
+
dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
452 |
+
if EVEN_N_BLOCK:
|
453 |
+
k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd)
|
454 |
+
v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd)
|
455 |
+
else:
|
456 |
+
k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX)
|
457 |
+
v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX)
|
458 |
+
|
459 |
+
# loop over rows
|
460 |
+
|
461 |
+
layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m
|
462 |
+
start_l = tl.load(layout_ptr).to(tl.int32)
|
463 |
+
end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32)
|
464 |
+
|
465 |
+
for row_idx_idx in range(start_l, end_l):
|
466 |
+
row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32)
|
467 |
+
start_m = row_idx * BLOCK_M
|
468 |
+
|
469 |
+
# offs_qm = start_m + tl.arange(0, BLOCK_M)
|
470 |
+
offs_m_curr = start_m + offs_m
|
471 |
+
q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
|
472 |
+
do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
|
473 |
+
dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
|
474 |
+
|
475 |
+
# load q, k, v, do on-chip
|
476 |
+
if EVEN_M_BLOCK:
|
477 |
+
q = tl.load(q_ptrs)
|
478 |
+
else:
|
479 |
+
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX)
|
480 |
+
# re-compute p = softmax(qk, dim=-1).T
|
481 |
+
# NOTE: `do` is pre-divided by `l`; no normalization here
|
482 |
+
qk = tl.dot(q, tl.trans(k))
|
483 |
+
|
484 |
+
if NUM_DBLOCKS >= 2:
|
485 |
+
if EVEN_M_BLOCK:
|
486 |
+
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
|
487 |
+
else:
|
488 |
+
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX)
|
489 |
+
qk += tl.dot(q2, tl.trans(k2))
|
490 |
+
|
491 |
+
qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf'))
|
492 |
+
|
493 |
+
if EVEN_M_BLOCK:
|
494 |
+
m = tl.load(m_ptrs + offs_m_curr)
|
495 |
+
else:
|
496 |
+
m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
|
497 |
+
p = tl.exp(qk * sm_scale - m[:, None])
|
498 |
+
|
499 |
+
# compute dv
|
500 |
+
if EVEN_M_BLOCK:
|
501 |
+
do = tl.load(do_ptrs)
|
502 |
+
else:
|
503 |
+
do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX)
|
504 |
+
|
505 |
+
if NUM_DBLOCKS >= 2:
|
506 |
+
if EVEN_M_BLOCK:
|
507 |
+
do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od)
|
508 |
+
else:
|
509 |
+
do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX)
|
510 |
+
|
511 |
+
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
|
512 |
+
|
513 |
+
if NUM_DBLOCKS >= 2:
|
514 |
+
dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2)
|
515 |
+
|
516 |
+
# compute dp = dot(v, do)
|
517 |
+
if EVEN_M_BLOCK:
|
518 |
+
Di = tl.load(D_ptrs + offs_m_curr)
|
519 |
+
else:
|
520 |
+
Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
|
521 |
+
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
522 |
+
dp += tl.dot(do, tl.trans(v))
|
523 |
+
|
524 |
+
if NUM_DBLOCKS >= 2:
|
525 |
+
dp += tl.dot(do2, tl.trans(v2))
|
526 |
+
|
527 |
+
# compute ds = p * (dp - delta[:, None])
|
528 |
+
ds = p * dp * sm_scale
|
529 |
+
# compute dk = dot(ds.T, q)
|
530 |
+
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
|
531 |
+
if NUM_DBLOCKS >= 2:
|
532 |
+
dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2)
|
533 |
+
|
534 |
+
# # compute dq
|
535 |
+
dq = tl.dot(ds.to(Q.dtype.element_ty), k)
|
536 |
+
if EVEN_M_BLOCK:
|
537 |
+
tl.atomic_add(dq_ptrs, dq)
|
538 |
+
else:
|
539 |
+
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX)
|
540 |
+
|
541 |
+
if NUM_DBLOCKS >= 2:
|
542 |
+
dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2)
|
543 |
+
dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od
|
544 |
+
if EVEN_M_BLOCK:
|
545 |
+
tl.atomic_add(dq_ptrs2, dq2)
|
546 |
+
else:
|
547 |
+
tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX)
|
548 |
+
|
549 |
+
# write-back
|
550 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
|
551 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
|
552 |
+
if EVEN_N_BLOCK:
|
553 |
+
tl.store(dv_ptrs, dv)
|
554 |
+
tl.store(dk_ptrs, dk)
|
555 |
+
else:
|
556 |
+
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX)
|
557 |
+
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX)
|
558 |
+
|
559 |
+
if NUM_DBLOCKS >= 2:
|
560 |
+
dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od
|
561 |
+
dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od
|
562 |
+
if EVEN_N_BLOCK:
|
563 |
+
tl.store(dv_ptrs2, dv2)
|
564 |
+
tl.store(dk_ptrs2, dk2)
|
565 |
+
else:
|
566 |
+
tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX)
|
567 |
+
tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX)
|
568 |
+
|
569 |
+
|
570 |
+
|
571 |
+
def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None):
|
572 |
+
'''
|
573 |
+
:param q, k, v: [batch, n_heads, seq_len, model_dim]. len of q is allowed to be different than k/v.
|
574 |
+
:param layout_crow_indices, layout_col_indices: same as CSR.crow_indices, and CSR.col_indices used to preresent a sparse tensor.
|
575 |
+
Each element represent a block, i.e, all elements in a block to be attentdd, or not attended at all..
|
576 |
+
'''
|
577 |
+
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
|
578 |
+
assert k.shape[2] == v.shape[2]
|
579 |
+
o = out if out is not None else torch.empty_like(q).contiguous()
|
580 |
+
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
|
581 |
+
|
582 |
+
q_rounded_len = grid[0] * BLOCK_M
|
583 |
+
tmp = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
|
584 |
+
|
585 |
+
if inference is None:
|
586 |
+
inference = (not q.requires_grad) and (not k.requires_grad) and (not v.requires_grad)
|
587 |
+
|
588 |
+
if inference:
|
589 |
+
L, m = tmp, tmp # no need to use create new tensor
|
590 |
+
else:
|
591 |
+
L = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
|
592 |
+
m = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
|
593 |
+
|
594 |
+
if layout_col_indices.dim() == 1:
|
595 |
+
layout_crow_indices = layout_crow_indices[None].expand(q.shape[1] , -1)
|
596 |
+
layout_col_indices = layout_col_indices[None].expand(q.shape[1] , -1)
|
597 |
+
|
598 |
+
assert q.shape[-1] in [64, 128]
|
599 |
+
BLOCK_DMODEL = 64
|
600 |
+
|
601 |
+
if num_warps is None:
|
602 |
+
MIN_D = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL)
|
603 |
+
num_warps = max(1, 2 ** int(math.log2(MIN_D / 16)))
|
604 |
+
# print(f'> {BLOCK_M=}, {BLOCK_N=}, {BLOCK_DMODEL=}, {num_warps=}, {num_stages=}')
|
605 |
+
else:
|
606 |
+
assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be power of 2, but got {num_warps}.'''
|
607 |
+
|
608 |
+
## For debugging:
|
609 |
+
# print(f'>> {q.shape=}, {k.shape=}, {BLOCK_M=}, {BLOCK_N=}, {num_warps=}, {BLOCK_DMODEL=}, {q.stride()=}, {k.stride()=}')
|
610 |
+
# print(f'>> {layout_crow_indices=}\n{layout_col_indices=}\n {layout_crow_indices.stride()=}, {layout_crow_indices.stride()=}')
|
611 |
+
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
|
612 |
+
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
|
613 |
+
|
614 |
+
with torch.cuda.device(q.device.index):
|
615 |
+
_fwd_kernel[grid](
|
616 |
+
q, k, v, sm_scale,
|
617 |
+
layout_crow_indices,
|
618 |
+
layout_col_indices,
|
619 |
+
layout_crow_indices.stride(0), layout_crow_indices.stride(1),
|
620 |
+
layout_col_indices.stride(0), layout_col_indices.stride(1),
|
621 |
+
tmp, L, m,
|
622 |
+
o,
|
623 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
624 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
625 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
626 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
627 |
+
q.shape[0], q.shape[1], k.shape[2],
|
628 |
+
k.shape[2] - q.shape[2],
|
629 |
+
q_rounded_len,
|
630 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
|
631 |
+
BLOCK_DMODEL=BLOCK_DMODEL,
|
632 |
+
EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
|
633 |
+
EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
|
634 |
+
INFERENCE=inference,
|
635 |
+
NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
|
636 |
+
num_warps=num_warps,
|
637 |
+
num_stages=num_stages,
|
638 |
+
)
|
639 |
+
if inference:
|
640 |
+
L, m = None, None
|
641 |
+
|
642 |
+
ctx.save_for_backward(q, k, v, o, L, m, layout_crow_indices, layout_col_indices)
|
643 |
+
ctx.BLOCK_M = BLOCK_M
|
644 |
+
ctx.BLOCK_N = BLOCK_N
|
645 |
+
ctx.BLOCK_DMODEL = BLOCK_DMODEL
|
646 |
+
# ctx.BLOCK = BLOCK
|
647 |
+
ctx.grid = grid
|
648 |
+
ctx.sm_scale = sm_scale
|
649 |
+
ctx.num_warps = num_warps
|
650 |
+
ctx.num_stages = num_stages
|
651 |
+
return o
|
652 |
+
|
653 |
+
|
654 |
+
def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None):
|
655 |
+
# q, k, v, o, l, m = ctx.saved_tensors
|
656 |
+
q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
|
657 |
+
|
658 |
+
## this following too slow to do online, so get it from inputs, which is cached.
|
659 |
+
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
|
660 |
+
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
|
661 |
+
|
662 |
+
if not do.is_contiguous():
|
663 |
+
do = do.contiguous()
|
664 |
+
## for debugging
|
665 |
+
# print(f'----> do is not contiguous: {do.stride()=}')
|
666 |
+
# raise ValueError(f'>>>> output grad is not contiguous: {do.stride()=}')
|
667 |
+
|
668 |
+
if not o.is_contiguous():
|
669 |
+
# TODO: currently only work with contiguous q/k/v.
|
670 |
+
raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.')
|
671 |
+
|
672 |
+
|
673 |
+
if layout_ccol_indices.dim() == 1:
|
674 |
+
layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1)
|
675 |
+
layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1)
|
676 |
+
|
677 |
+
# do = do.contiguous()
|
678 |
+
dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32)
|
679 |
+
dk = dk if dk is not None else torch.empty_like(k)
|
680 |
+
dv =dv if dv is not None else torch.empty_like(v)
|
681 |
+
do_scaled = torch.empty_like(do)
|
682 |
+
delta = torch.empty_like(l)
|
683 |
+
|
684 |
+
assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride()
|
685 |
+
|
686 |
+
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
687 |
+
o, do, l,
|
688 |
+
do_scaled, delta,
|
689 |
+
k.shape[2],
|
690 |
+
BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1],
|
691 |
+
)
|
692 |
+
|
693 |
+
grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1])
|
694 |
+
|
695 |
+
_bwd_kernel[grid](
|
696 |
+
q, k, v, ctx.sm_scale,
|
697 |
+
layout_ccol_indices,
|
698 |
+
layout_row_indices,
|
699 |
+
layout_ccol_indices.stride(0), layout_ccol_indices.stride(1),
|
700 |
+
layout_row_indices.stride(0), layout_row_indices.stride(1),
|
701 |
+
o, do_scaled,
|
702 |
+
dq, dk, dv,
|
703 |
+
l, m,
|
704 |
+
delta,
|
705 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
706 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
707 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
708 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
709 |
+
q.shape[0], q.shape[1], q.shape[2],
|
710 |
+
ctx.grid[0],
|
711 |
+
BLOCK_M=ctx.BLOCK_M,
|
712 |
+
BLOCK_N=ctx.BLOCK_N,
|
713 |
+
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
|
714 |
+
NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL,
|
715 |
+
num_warps=ctx.num_warps,
|
716 |
+
num_stages=1,
|
717 |
+
)
|
718 |
+
return dq, dk, dv, None, None, None
|
719 |
+
|
720 |
+
|
721 |
+
class _sparse_attention(torch.autograd.Function):
|
722 |
+
|
723 |
+
@staticmethod
|
724 |
+
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
|
725 |
+
BLOCK = 128
|
726 |
+
# shape constraints
|
727 |
+
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK, BLOCK)
|
728 |
+
|
729 |
+
@staticmethod
|
730 |
+
def backward(ctx, do):
|
731 |
+
# q, k, v, o, l, m = ctx.saved_tensors
|
732 |
+
q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
|
733 |
+
# TODO: the following is very inefficient.
|
734 |
+
# layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
|
735 |
+
layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
|
736 |
+
return _backward(ctx, do, layout_ccol_indices, layout_row_indices)
|
737 |
+
|
738 |
+
|
739 |
+
|
740 |
+
# suppressed
|
741 |
+
class _sparse_attention_inference(_sparse_attention):
|
742 |
+
# TODO: does not work now, as BLOCK_M cannot be <1, as shape for tl.dot cannot be smaller than 16.
|
743 |
+
@staticmethod
|
744 |
+
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
|
745 |
+
BLOCK = 128
|
746 |
+
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, 1, BLOCK)
|
747 |
+
|
748 |
+
|
749 |
+
|
750 |
+
def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs):
|
751 |
+
class _sparse_attention_config(_sparse_attention):
|
752 |
+
@staticmethod
|
753 |
+
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
|
754 |
+
# shape constraints
|
755 |
+
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
|
756 |
+
**kwargs
|
757 |
+
)
|
758 |
+
return _sparse_attention_config.apply
|
759 |
+
|
760 |
+
|
761 |
+
@lru_cache(maxsize=8)
|
762 |
+
def get_local_strided_sparse_attention_op(
|
763 |
+
n_heads: int,
|
764 |
+
max_seq_len:int,
|
765 |
+
sparse_block_size: int=128,
|
766 |
+
local_blocks: int=4,
|
767 |
+
vert_stride: int=4,
|
768 |
+
homo_head: bool=False,
|
769 |
+
dtype=torch.bfloat16,
|
770 |
+
device='cuda',
|
771 |
+
active_head_range=None,
|
772 |
+
verbose=True,
|
773 |
+
**kwargs):
|
774 |
+
'''
|
775 |
+
:param n_heads: total number of attention heads (regardless of tensor/model parallel)
|
776 |
+
:param max_seq_len: max sequence length. Need to be bigger or equal to the length of sequences.
|
777 |
+
:param sparse_block_size: sparse block size. Default to 128
|
778 |
+
:param local_blocks: number of nearest block to attend to. Default to 4, i.e., attention to previous 4xblock_size tokens.
|
779 |
+
:param vert_stride: Default to 4. Meaning
|
780 |
+
:param homo_head: if all head shared the same pattern.
|
781 |
+
:param active_head_range: tuple of start & end of the heads, e..g, (8, 16). Default to use all heads.
|
782 |
+
Mainly for tensor/model parallelization where heads are splitted to different GPUs.
|
783 |
+
'''
|
784 |
+
|
785 |
+
if verbose:
|
786 |
+
print((f'> new block_sparse_attn op constructed with config: '
|
787 |
+
f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, '
|
788 |
+
f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}'))
|
789 |
+
# assert math.log2(max_seq_len) % 2 == 0, f"max_seq_len should be power of 2 to be more efficient"
|
790 |
+
_, block_sparse_pattern, _ = _get_sparse_attn_mask(n_heads, max_seq_len, max_seq_len, dtype, device,
|
791 |
+
BLOCK=sparse_block_size, local_blocks=local_blocks,
|
792 |
+
vert_stride=vert_stride, homo_head=homo_head,
|
793 |
+
return_dense=False)
|
794 |
+
if (not homo_head) and (active_head_range is not None):
|
795 |
+
assert isinstance(active_head_range, tuple)
|
796 |
+
assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.'
|
797 |
+
h_start, h_end = active_head_range
|
798 |
+
block_sparse_pattern = block_sparse_pattern[h_start:h_end]
|
799 |
+
# print(block_sparse_pattern)
|
800 |
+
return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs)
|
801 |
+
|
802 |
+
|
803 |
+
def get_sparse_attn_op(
|
804 |
+
sparse_pattern: torch.tensor,
|
805 |
+
sparse_block_size: int=128,
|
806 |
+
kernel_block_size=128,
|
807 |
+
qkv_format='q,k,v',
|
808 |
+
**kwargs):
|
809 |
+
'''
|
810 |
+
Ccreate a block-sparse op with fixed layout. This is to avoid the need to of create CSR layout and convert it to CSC layout everytime,
|
811 |
+
which is very inefficient (use python loops on CPU. PyTorch 1.13 supports CSR->CSC, may help.)
|
812 |
+
|
813 |
+
:param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.
|
814 |
+
This tensor should have lower-triangular matrices on the last 2 dimensions for causal attention
|
815 |
+
:param sparse_block_size: sparse block size. Default to 128
|
816 |
+
:param kernel_block_size: the tile/block size to launch a triton instance. Default to None, i.e., same as `sparse_block_size`
|
817 |
+
:param qkv_format: Choices=['q,k,v', 'q, kv', 'qkv'], i.e., separated q,k,v, or kv packed, or qkv packed. Currently, only 'q,k,v' is supported.
|
818 |
+
|
819 |
+
:param kwargs: keyward arguments passed to `_forward`
|
820 |
+
'''
|
821 |
+
# assert qkv_format in ('q,k,v', 'q, kv', 'qkv') # to save from running `concat` at forward/backward
|
822 |
+
|
823 |
+
assert qkv_format == 'q,k,v'
|
824 |
+
|
825 |
+
if kernel_block_size is None:
|
826 |
+
kernel_block_size = sparse_block_size
|
827 |
+
else:
|
828 |
+
assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}."
|
829 |
+
assert kernel_block_size >=16 and math.log2(kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {kernel_block_size} is given"
|
830 |
+
|
831 |
+
|
832 |
+
# print(f'>> {sparse_pattern.shape=}')
|
833 |
+
# print(f'{sparse_pattern=}')
|
834 |
+
if sparse_block_size // kernel_block_size > 1:
|
835 |
+
_mul = sparse_block_size // kernel_block_size
|
836 |
+
# need to consider if block_m and block_n are different
|
837 |
+
sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul))
|
838 |
+
num_sparse_blocks = sparse_pattern.size(-1)
|
839 |
+
block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
|
840 |
+
sparse_pattern *= block_causal_mask.type_as(sparse_pattern)
|
841 |
+
# print(f'>> after: {sparse_pattern.shape=}')
|
842 |
+
# print(f'{sparse_pattern=}')
|
843 |
+
|
844 |
+
BLOCK_N = kernel_block_size
|
845 |
+
NUM_BLOCK = sparse_pattern.size(-1)
|
846 |
+
MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK
|
847 |
+
|
848 |
+
grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern)
|
849 |
+
# sparse csc layout for backward
|
850 |
+
grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern)
|
851 |
+
|
852 |
+
|
853 |
+
# cache GPU backward layout. limit the size to avoid OOM as time goes.
|
854 |
+
# For inference, one only needs to cache one block as sequence length always increases
|
855 |
+
# Therefore, this cache needs to be reconstructed per every `block_size`-steps.
|
856 |
+
# For training/finetune, set to 8 to increase cache hit.
|
857 |
+
# Given an input, the block_len will be the same for all layers, so cache is very helpful.
|
858 |
+
|
859 |
+
max_cache_size = 1 if kwargs.get('inference', False) else 8
|
860 |
+
|
861 |
+
@lru_cache(maxsize=max_cache_size)
|
862 |
+
def get_backward_layout_by_block_len(block_len):
|
863 |
+
assert block_len <= NUM_BLOCK
|
864 |
+
if block_len == NUM_BLOCK:
|
865 |
+
return (grand_layout_ccol_indices, grand_layout_row_indices)
|
866 |
+
return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len])
|
867 |
+
|
868 |
+
# for debugging
|
869 |
+
# if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
|
870 |
+
# print(f'> {sparse_pattern.cpu().tolist()=}')
|
871 |
+
# print('----')
|
872 |
+
# print(f'> {grand_layout_crow_indices.cpu().tolist()=}\n{grand_layout_col_indices.cpu().tolist()=}')
|
873 |
+
|
874 |
+
|
875 |
+
# q, k, v separated
|
876 |
+
class _q_k_v_sparse_attention(torch.autograd.Function):
|
877 |
+
@staticmethod
|
878 |
+
def forward(ctx, q, k, v, sm_scale):
|
879 |
+
# assert q.shape[2] == 1 or q.shape[2] == k.shape[2]
|
880 |
+
# shape constraints
|
881 |
+
MIN_BLOCK_SIZE = 16
|
882 |
+
assert BLOCK_N >= MIN_BLOCK_SIZE
|
883 |
+
BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N # BLOCK_M has to be power of 2
|
884 |
+
|
885 |
+
# this following code only works for causal attention
|
886 |
+
K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size)
|
887 |
+
# Q_START_BLOCKS = K_BLOCKS - 1 if q.shape[2] == 1 else 0
|
888 |
+
Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N)
|
889 |
+
# print(Q_START_BLOCKS, K_BLOCKS)
|
890 |
+
|
891 |
+
layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1]
|
892 |
+
layout_col_indices = grand_layout_col_indices
|
893 |
+
# print(BLOCK_M, BLOCK_N, Q_START_BLOCKS, K_BLOCKS+1, layout_crow_indices, layout_col_indices)
|
894 |
+
|
895 |
+
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
|
896 |
+
**kwargs
|
897 |
+
)
|
898 |
+
@staticmethod
|
899 |
+
def backward(ctx, do):
|
900 |
+
q, k = ctx.saved_tensors[:2]
|
901 |
+
assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.'
|
902 |
+
# assume q, k have same length
|
903 |
+
block_len = triton.cdiv(do.shape[2], kernel_block_size)
|
904 |
+
backward_layout = get_backward_layout_by_block_len(block_len)
|
905 |
+
return _backward(ctx, do, *backward_layout)[:4]
|
906 |
+
|
907 |
+
|
908 |
+
def _q_k_v_sparse_attention_fn(*args):
|
909 |
+
return _q_k_v_sparse_attention.apply(*args)
|
910 |
+
|
911 |
+
_q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern
|
912 |
+
_q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices
|
913 |
+
_q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices
|
914 |
+
_q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices
|
915 |
+
_q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices
|
916 |
+
|
917 |
+
return _q_k_v_sparse_attention_fn
|
918 |
+
|
919 |
+
###########################################################
|
920 |
+
###########################################################
|
921 |
+
|
922 |
+
###########################################################
|
923 |
+
################ Inference Kernels ########################
|
924 |
+
###########################################################
|
925 |
+
|
926 |
+
def blocksparse_flash_attn_padded_fwd(
|
927 |
+
q, k, v, # (batch, tokens, n_heads, head_size)
|
928 |
+
sm_scale,
|
929 |
+
sparse_layout,
|
930 |
+
*,
|
931 |
+
left_paddings = None,
|
932 |
+
seqlens = None,
|
933 |
+
block_size = 64,
|
934 |
+
max_seqlen = None
|
935 |
+
):
|
936 |
+
'''
|
937 |
+
q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size)
|
938 |
+
left_paddings: (batch, ), number of left paddings for each sample.
|
939 |
+
seqlens: can be used to specify right padding. No need to specify if left_paddings is used.
|
940 |
+
'''
|
941 |
+
batches, q_len, n_heads, head_size = q.shape
|
942 |
+
_, k_len, n_kv_heads, _ = k.shape
|
943 |
+
|
944 |
+
|
945 |
+
assert q.dim() == k.dim() == v.dim() == 4
|
946 |
+
assert q.size(2) % k.size(2) == 0
|
947 |
+
assert q.size(0) == k.size(0) and q.size(3) == k.size(3)
|
948 |
+
assert k.shape == v.shape # TODO: allow diff head_size for k, v
|
949 |
+
assert q_len == 1 or q_len == k_len, \
|
950 |
+
f'q length can only 1 for decoding for same as k length for prefilling.'
|
951 |
+
|
952 |
+
q_k_ratio = q.size(2) // k.size(2)
|
953 |
+
|
954 |
+
if max_seqlen:
|
955 |
+
assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.'
|
956 |
+
|
957 |
+
# paddings always has zero output, a little slower than using empty
|
958 |
+
out = q.new_zeros(q.shape)
|
959 |
+
|
960 |
+
layout_crow_indices, layout_col_indices = sparse_layout
|
961 |
+
block_d = triton.next_power_of_2(head_size)
|
962 |
+
|
963 |
+
if left_paddings is not None:
|
964 |
+
assert left_paddings.shape == (batches,)
|
965 |
+
k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous()
|
966 |
+
else:
|
967 |
+
k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device)
|
968 |
+
|
969 |
+
if seqlens is not None:
|
970 |
+
k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts)
|
971 |
+
assert k_batch_ends.max() <= k_len, f'seqlens (+left_paddings if any) exceeds seqlen.'
|
972 |
+
else:
|
973 |
+
k_batch_ends = torch.zeros_like(k_batch_starts) + k_len
|
974 |
+
|
975 |
+
if q_len == 1:
|
976 |
+
q_batch_starts = torch.zeros_like(k_batch_starts)
|
977 |
+
q_batch_ends = q_batch_starts + 1
|
978 |
+
else:
|
979 |
+
q_batch_starts = k_batch_starts
|
980 |
+
q_batch_ends = k_batch_ends
|
981 |
+
|
982 |
+
# switch to use cpu to avoid too many kernel lauch when iterate over
|
983 |
+
q_lens = (q_batch_ends - q_batch_starts).cpu()
|
984 |
+
n_blocks = (q_lens + block_size - 1) // block_size
|
985 |
+
|
986 |
+
q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
|
987 |
+
dtype=q_batch_starts.dtype,
|
988 |
+
device=q_batch_starts.device)
|
989 |
+
q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
|
990 |
+
dtype=q_batch_starts.dtype,
|
991 |
+
device=q_batch_starts.device)
|
992 |
+
|
993 |
+
grid = (len(q_start_sids), n_heads)
|
994 |
+
|
995 |
+
with torch.cuda.device(q.device.index):
|
996 |
+
_fwd_kernel_batch_inference[grid](
|
997 |
+
q, k, v, out,
|
998 |
+
sm_scale,
|
999 |
+
q_batch_starts,
|
1000 |
+
q_batch_ends,
|
1001 |
+
k_batch_starts,
|
1002 |
+
k_batch_ends,
|
1003 |
+
q_batch_ids,
|
1004 |
+
q_start_sids,
|
1005 |
+
|
1006 |
+
*q.stride(),
|
1007 |
+
*k.stride(),
|
1008 |
+
*v.stride(),
|
1009 |
+
*out.stride(),
|
1010 |
+
|
1011 |
+
layout_crow_indices,
|
1012 |
+
layout_col_indices,
|
1013 |
+
*layout_crow_indices.stride(),
|
1014 |
+
*layout_col_indices.stride(),
|
1015 |
+
|
1016 |
+
q_k_ratio,
|
1017 |
+
HAS_BATCH_DIM = True,
|
1018 |
+
D_HEAD = head_size,
|
1019 |
+
BLOCK_M = block_size,
|
1020 |
+
BLOCK_N = block_size,
|
1021 |
+
BLOCK_D = block_d,
|
1022 |
+
BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
|
1023 |
+
EVEN_D = block_d == head_size,
|
1024 |
+
num_warps = 1 if q_len == 1 else 4,
|
1025 |
+
num_stages = 1
|
1026 |
+
)
|
1027 |
+
|
1028 |
+
|
1029 |
+
return out
|
1030 |
+
|
1031 |
+
|
1032 |
+
def blocksparse_flash_attn_varlen_fwd(
|
1033 |
+
q, k, v, # (#tokens, n_heads, head_size)
|
1034 |
+
cu_seqlens_k,
|
1035 |
+
cu_seqlens_q,
|
1036 |
+
sm_scale,
|
1037 |
+
sparse_layout,
|
1038 |
+
*,
|
1039 |
+
block_size=64,
|
1040 |
+
max_seqlen = None
|
1041 |
+
):
|
1042 |
+
# split q to blocks
|
1043 |
+
_, n_heads, head_size = q.shape
|
1044 |
+
batch_size = cu_seqlens_k.size(0) - 1
|
1045 |
+
|
1046 |
+
|
1047 |
+
# print(f'> {q.shape=}, {k.shape=}')
|
1048 |
+
assert q.dim() == k.dim() == v.dim() == 3
|
1049 |
+
assert q.size(1) % k.size(1) == 0
|
1050 |
+
assert q.size(2) == k.size(2)
|
1051 |
+
assert k.shape == v.shape # TODO: allow diff head_size for k, v
|
1052 |
+
assert cu_seqlens_k.dim() == 1
|
1053 |
+
|
1054 |
+
q_k_ratio = q.size(1) // k.size(1)
|
1055 |
+
|
1056 |
+
if cu_seqlens_q is None:
|
1057 |
+
if q.size(0) == batch_size: # decoding only
|
1058 |
+
cu_seqlens_q = torch.arange(0, batch_size + 1,
|
1059 |
+
dtype=cu_seqlens_k.dtype,
|
1060 |
+
device=cu_seqlens_k.device)
|
1061 |
+
elif q.size(0) == k.size(0):
|
1062 |
+
cu_seqlens_q = cu_seqlens_k
|
1063 |
+
else:
|
1064 |
+
raise ValueError('cu_seqlens_q must be specified if it is mix of prefilling and decoding.')
|
1065 |
+
else:
|
1066 |
+
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
|
1067 |
+
|
1068 |
+
# switch to use cpu to avoid too many kernel lauch when iterate over
|
1069 |
+
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
|
1070 |
+
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
|
1071 |
+
|
1072 |
+
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \
|
1073 |
+
'length of q should either be 1 (decoding) or same as k (prefilling).'
|
1074 |
+
|
1075 |
+
if max_seqlen:
|
1076 |
+
assert k_lens.max() <= max_seqlen
|
1077 |
+
|
1078 |
+
n_blocks = (q_lens + block_size - 1) // block_size
|
1079 |
+
|
1080 |
+
q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
|
1081 |
+
dtype=cu_seqlens_q.dtype,
|
1082 |
+
device=cu_seqlens_q.device)
|
1083 |
+
q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
|
1084 |
+
dtype=cu_seqlens_q.dtype,
|
1085 |
+
device=cu_seqlens_q.device)
|
1086 |
+
|
1087 |
+
|
1088 |
+
out = q.new_empty(q.shape)
|
1089 |
+
cu_seqlens_q = cu_seqlens_q.contiguous()
|
1090 |
+
cu_seqlens_k = cu_seqlens_k.contiguous()
|
1091 |
+
|
1092 |
+
layout_crow_indices, layout_col_indices = sparse_layout
|
1093 |
+
block_d = triton.next_power_of_2(head_size)
|
1094 |
+
|
1095 |
+
decoding_only = (q_lens == 1).all()
|
1096 |
+
|
1097 |
+
grid = (len(q_start_sids), n_heads)
|
1098 |
+
|
1099 |
+
with torch.cuda.device(q.device.index):
|
1100 |
+
_fwd_kernel_batch_inference[grid](
|
1101 |
+
q, k, v, out,
|
1102 |
+
sm_scale,
|
1103 |
+
cu_seqlens_q[:-1],
|
1104 |
+
cu_seqlens_q[1:],
|
1105 |
+
cu_seqlens_k[:-1],
|
1106 |
+
cu_seqlens_k[1:],
|
1107 |
+
q_batch_ids,
|
1108 |
+
q_start_sids,
|
1109 |
+
|
1110 |
+
0, *q.stride(),
|
1111 |
+
0, *k.stride(),
|
1112 |
+
0, *v.stride(),
|
1113 |
+
0, *out.stride(),
|
1114 |
+
|
1115 |
+
layout_crow_indices,
|
1116 |
+
layout_col_indices,
|
1117 |
+
*layout_crow_indices.stride(),
|
1118 |
+
*layout_col_indices.stride(),
|
1119 |
+
|
1120 |
+
q_k_ratio,
|
1121 |
+
HAS_BATCH_DIM = False,
|
1122 |
+
D_HEAD = head_size,
|
1123 |
+
BLOCK_M = block_size,
|
1124 |
+
BLOCK_N = block_size,
|
1125 |
+
BLOCK_D = block_d,
|
1126 |
+
BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
|
1127 |
+
EVEN_D = block_d == head_size,
|
1128 |
+
num_warps = 1 if decoding_only else 4,
|
1129 |
+
num_stages = 3
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
return out
|
1133 |
+
|
1134 |
+
|
1135 |
+
@triton.jit
|
1136 |
+
def _fwd_kernel_inner(
|
1137 |
+
acc, l_i, m_i,
|
1138 |
+
q, Q,
|
1139 |
+
k_block_col_idx,
|
1140 |
+
layout_col_ptr,
|
1141 |
+
layout_col_stride_h, layout_col_stride_m,
|
1142 |
+
k_ptrs,
|
1143 |
+
v_ptrs,
|
1144 |
+
off_h, offs_m, offs_n, offs_d,
|
1145 |
+
stride_kt, stride_vt,
|
1146 |
+
sm_scale,
|
1147 |
+
k_seqlen,
|
1148 |
+
past_len,
|
1149 |
+
LAST_K_BLOCK: tl.constexpr,
|
1150 |
+
BLOCK_M_LOADING: tl.constexpr,
|
1151 |
+
BLOCK_N: tl.constexpr,
|
1152 |
+
D_HEAD: tl.constexpr,
|
1153 |
+
EVEN_D: tl.constexpr,
|
1154 |
+
M_LT_N: tl.constexpr
|
1155 |
+
):
|
1156 |
+
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32)
|
1157 |
+
start_n = k_block_id * BLOCK_N
|
1158 |
+
# -- compute qk ----
|
1159 |
+
if LAST_K_BLOCK:
|
1160 |
+
if EVEN_D:
|
1161 |
+
k = tl.load(k_ptrs + start_n * stride_kt,
|
1162 |
+
mask=offs_n[None, :] + start_n < k_seqlen)
|
1163 |
+
else:
|
1164 |
+
# mask = mask & (offs_d[:, ])
|
1165 |
+
k = tl.load(k_ptrs + start_n * stride_kt,
|
1166 |
+
mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD))
|
1167 |
+
else:
|
1168 |
+
if EVEN_D:
|
1169 |
+
k = tl.load(k_ptrs + start_n * stride_kt)
|
1170 |
+
else:
|
1171 |
+
k = tl.load(k_ptrs + start_n * stride_kt,
|
1172 |
+
mask=offs_d[:, None] < D_HEAD)
|
1173 |
+
|
1174 |
+
|
1175 |
+
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
1176 |
+
qk += tl.dot(q, k)
|
1177 |
+
|
1178 |
+
qk *= sm_scale
|
1179 |
+
|
1180 |
+
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
|
1181 |
+
if LAST_K_BLOCK | M_LT_N:
|
1182 |
+
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))
|
1183 |
+
|
1184 |
+
# -- compute m_ij, p, l_ij
|
1185 |
+
m_ij = tl.max(qk, 1)
|
1186 |
+
p = tl.exp(qk - m_ij[:, None])
|
1187 |
+
|
1188 |
+
l_ij = tl.sum(p, 1)
|
1189 |
+
# -- update m_i and l_i
|
1190 |
+
m_i_new = tl.maximum(m_i, m_ij)
|
1191 |
+
alpha = tl.exp(m_i - m_i_new)
|
1192 |
+
beta = tl.exp(m_ij - m_i_new)
|
1193 |
+
l_i_new = alpha * l_i + beta * l_ij
|
1194 |
+
# -- update output accumulator --
|
1195 |
+
# scale p
|
1196 |
+
p_scale = beta / l_i_new
|
1197 |
+
p = p * p_scale[:, None]
|
1198 |
+
# scale acc
|
1199 |
+
acc_scale = l_i / l_i_new * alpha
|
1200 |
+
acc = acc * acc_scale[:, None]
|
1201 |
+
|
1202 |
+
p = p.to(Q.dtype.element_ty)
|
1203 |
+
# update acc
|
1204 |
+
if LAST_K_BLOCK:
|
1205 |
+
if EVEN_D:
|
1206 |
+
v = tl.load(v_ptrs + start_n * stride_vt,
|
1207 |
+
mask=offs_n[:, None] + start_n < k_seqlen)
|
1208 |
+
else:
|
1209 |
+
v = tl.load(v_ptrs + start_n * stride_vt,
|
1210 |
+
mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD))
|
1211 |
+
else:
|
1212 |
+
if EVEN_D:
|
1213 |
+
v = tl.load(v_ptrs + start_n * stride_vt)
|
1214 |
+
else:
|
1215 |
+
v = tl.load(v_ptrs + start_n * stride_vt,
|
1216 |
+
mask=offs_d[None, :] < D_HEAD)
|
1217 |
+
|
1218 |
+
acc += tl.dot(p, v)
|
1219 |
+
# update m_i and l_i
|
1220 |
+
l_i = l_i_new
|
1221 |
+
m_i = m_i_new
|
1222 |
+
return acc, l_i, m_i
|
1223 |
+
|
1224 |
+
|
1225 |
+
@triton.heuristics(
|
1226 |
+
{
|
1227 |
+
'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'],
|
1228 |
+
}
|
1229 |
+
)
|
1230 |
+
@triton.jit
|
1231 |
+
def _fwd_kernel_batch_inference(
|
1232 |
+
Q, K, V, Out,
|
1233 |
+
|
1234 |
+
sm_scale,
|
1235 |
+
q_batch_starts,
|
1236 |
+
q_batch_ends,
|
1237 |
+
k_batch_starts,
|
1238 |
+
k_batch_ends,
|
1239 |
+
q_batch_ids,
|
1240 |
+
q_start_sids,
|
1241 |
+
|
1242 |
+
stride_qb, stride_qt, stride_qh, stride_qd,
|
1243 |
+
stride_kb, stride_kt, stride_kh, stride_kd,
|
1244 |
+
stride_vb, stride_vt, stride_vh, stride_vd,
|
1245 |
+
stride_ob, stride_ot, stride_oh, stride_od,
|
1246 |
+
|
1247 |
+
layout_crow_ptr,
|
1248 |
+
layout_col_ptr,
|
1249 |
+
layout_crow_stride_h, layout_crow_stride_m,
|
1250 |
+
layout_col_stride_h, layout_col_stride_m,
|
1251 |
+
|
1252 |
+
q_k_ratio,
|
1253 |
+
|
1254 |
+
HAS_BATCH_DIM: tl.constexpr,
|
1255 |
+
D_HEAD: tl.constexpr,
|
1256 |
+
BLOCK_M: tl.constexpr,
|
1257 |
+
BLOCK_N: tl.constexpr,
|
1258 |
+
BLOCK_D: tl.constexpr,
|
1259 |
+
BLOCK_M_LOADING: tl.constexpr,
|
1260 |
+
EVEN_D: tl.constexpr,
|
1261 |
+
M_LT_N: tl.constexpr
|
1262 |
+
):
|
1263 |
+
'''
|
1264 |
+
NOTATION:
|
1265 |
+
pid: position id
|
1266 |
+
sid: storage id
|
1267 |
+
sbid: storage block id
|
1268 |
+
pbid: position block id
|
1269 |
+
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
|
1270 |
+
|
1271 |
+
q and blocks in KV needs to be contiguous
|
1272 |
+
|
1273 |
+
Arguments:
|
1274 |
+
kv_seq_lens: for compute past_len
|
1275 |
+
kv_storage_offsets: similar to block_tables in vllm, except it is dynamic.
|
1276 |
+
TODO: fix this
|
1277 |
+
|
1278 |
+
TODO:
|
1279 |
+
Optimize grouped-attn
|
1280 |
+
|
1281 |
+
CUDA graph support issue
|
1282 |
+
1. grid is dynamic: vllm set up multiple cuda graph in decoding phase, with diff max token size (16, 32, ...)
|
1283 |
+
since we mix prompt and decoing phase here, it can be more complex.
|
1284 |
+
need to set up diff cuda-graph for diff (off_zm, off_z)
|
1285 |
+
|
1286 |
+
# indeed, q_batch_ids can be padded to maximum number of grid[0], i.e., assume all decoding
|
1287 |
+
therefore, cu_seqlens_q, kv_seq_lens
|
1288 |
+
|
1289 |
+
'''
|
1290 |
+
off_zm = tl.program_id(0)
|
1291 |
+
off_h = tl.program_id(1)
|
1292 |
+
|
1293 |
+
off_h_for_kv = off_h // q_k_ratio
|
1294 |
+
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
|
1295 |
+
q_start_sid = tl.load(q_start_sids + off_zm)
|
1296 |
+
start_m = q_start_sid // BLOCK_M
|
1297 |
+
|
1298 |
+
if HAS_BATCH_DIM:
|
1299 |
+
Q += off_z * stride_qb
|
1300 |
+
K += off_z * stride_kb
|
1301 |
+
V += off_z * stride_vb
|
1302 |
+
Out += off_z * stride_ob
|
1303 |
+
|
1304 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
|
1305 |
+
offs_n = tl.arange(0, BLOCK_N)
|
1306 |
+
offs_d = tl.arange(0, BLOCK_D)
|
1307 |
+
|
1308 |
+
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
|
1309 |
+
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
|
1310 |
+
|
1311 |
+
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
|
1312 |
+
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
|
1313 |
+
|
1314 |
+
past_len = k_seqlen - q_seqlen
|
1315 |
+
|
1316 |
+
Q += q_cu_start * stride_qt + off_h * stride_qh
|
1317 |
+
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
|
1318 |
+
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
|
1319 |
+
Out += q_cu_start * stride_ot + off_h * stride_oh
|
1320 |
+
|
1321 |
+
q_pbid = (past_len + q_start_sid) // BLOCK_M
|
1322 |
+
|
1323 |
+
if EVEN_D:
|
1324 |
+
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
1325 |
+
mask=offs_m[:, None] < q_seqlen)
|
1326 |
+
else:
|
1327 |
+
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
1328 |
+
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
1329 |
+
other=0)
|
1330 |
+
|
1331 |
+
sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m
|
1332 |
+
|
1333 |
+
# TODO: load at once, supported in new Triton
|
1334 |
+
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
|
1335 |
+
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
|
1336 |
+
|
1337 |
+
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf')
|
1338 |
+
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
|
1339 |
+
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
|
1340 |
+
|
1341 |
+
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
|
1342 |
+
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
|
1343 |
+
|
1344 |
+
for k_block_col_idx in range(k_block_start, k_block_end - 1):
|
1345 |
+
acc, l_i, m_i = _fwd_kernel_inner(
|
1346 |
+
acc, l_i, m_i,
|
1347 |
+
q, Q,
|
1348 |
+
k_block_col_idx,
|
1349 |
+
layout_col_ptr,
|
1350 |
+
layout_col_stride_h, layout_col_stride_m,
|
1351 |
+
k_ptrs,
|
1352 |
+
v_ptrs,
|
1353 |
+
off_h, offs_m, offs_n, offs_d,
|
1354 |
+
stride_kt, stride_vt,
|
1355 |
+
sm_scale,
|
1356 |
+
k_seqlen,
|
1357 |
+
past_len,
|
1358 |
+
False,
|
1359 |
+
BLOCK_M_LOADING,
|
1360 |
+
BLOCK_N,
|
1361 |
+
D_HEAD,
|
1362 |
+
EVEN_D,
|
1363 |
+
M_LT_N
|
1364 |
+
)
|
1365 |
+
|
1366 |
+
acc, l_i, m_i = _fwd_kernel_inner(
|
1367 |
+
acc, l_i, m_i,
|
1368 |
+
q, Q,
|
1369 |
+
k_block_end - 1,
|
1370 |
+
layout_col_ptr,
|
1371 |
+
layout_col_stride_h, layout_col_stride_m,
|
1372 |
+
k_ptrs,
|
1373 |
+
v_ptrs,
|
1374 |
+
off_h, offs_m, offs_n, offs_d,
|
1375 |
+
stride_kt, stride_vt,
|
1376 |
+
sm_scale,
|
1377 |
+
k_seqlen,
|
1378 |
+
past_len,
|
1379 |
+
True,
|
1380 |
+
BLOCK_M_LOADING,
|
1381 |
+
BLOCK_N,
|
1382 |
+
D_HEAD,
|
1383 |
+
EVEN_D,
|
1384 |
+
M_LT_N
|
1385 |
+
)
|
1386 |
+
|
1387 |
+
# write output
|
1388 |
+
if EVEN_D:
|
1389 |
+
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
|
1390 |
+
mask=offs_m[:, None] < q_seqlen)
|
1391 |
+
else:
|
1392 |
+
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
|
1393 |
+
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD))
|
1394 |
+
|
1395 |
+
|
1396 |
+
###########################################################
|
1397 |
+
###########################################################
|
1398 |
+
|
1399 |
+
###########################################################
|
1400 |
+
################## Testing Utilities ######################
|
1401 |
+
###########################################################
|
1402 |
+
|
1403 |
+
|
1404 |
+
def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None):
|
1405 |
+
'''
|
1406 |
+
q, k, v: shape=(batch, n_heads, seq, dim)
|
1407 |
+
'''
|
1408 |
+
# for verification
|
1409 |
+
if sm_scale is None:
|
1410 |
+
sm_scale = math.sqrt(float(q.size(-1)))
|
1411 |
+
|
1412 |
+
if block_attn_mask is not None:
|
1413 |
+
assert attn_mask is None
|
1414 |
+
outs = []
|
1415 |
+
for s in range(0, q.size(2), block_size):
|
1416 |
+
e = min(s + block_size, q.size(2))
|
1417 |
+
q_block = q[:, :, s:e]
|
1418 |
+
attn = torch.einsum('bhmd,bhnd->bhmn', q_block, k[:, :, :e]).float() * sm_scale
|
1419 |
+
mask = block_attn_mask[..., s // block_size, : (s // block_size + 1)]
|
1420 |
+
mask = torch.kron(mask, torch.ones(block_size, block_size, device=mask.device))
|
1421 |
+
mask[..., :, s:].masked_fill_(torch.arange(0, block_size)[:, None] <= torch.arange(0, block_size)[None, :], 0)
|
1422 |
+
attn = attn.masked_fill((1 - mask).bool(), float('-inf'))
|
1423 |
+
attn = attn.softmax(-1)
|
1424 |
+
out = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e])
|
1425 |
+
outs.append(out)
|
1426 |
+
torch_output = torch.cat(outs, dim=2)
|
1427 |
+
else:
|
1428 |
+
attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale
|
1429 |
+
# import ipdb; ipdb.set_trace()
|
1430 |
+
if attn_mask is not None:
|
1431 |
+
attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf'))
|
1432 |
+
# print(f'> torch attn: {attn.exp().sum(-1)=}')
|
1433 |
+
|
1434 |
+
attn = attn.softmax(-1)
|
1435 |
+
if do is not None:
|
1436 |
+
dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do)
|
1437 |
+
print(f'> torch_attn computed dv: {dv=}')
|
1438 |
+
torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v)
|
1439 |
+
return torch_output
|
1440 |
+
|
1441 |
+
###########################################################
|
1442 |
+
###########################################################
|
1443 |
+
|
1444 |
+
###########################################################
|
1445 |
+
#################### Unit Tests ###########################
|
1446 |
+
###########################################################
|
1447 |
+
|
1448 |
+
|
1449 |
+
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)])
|
1450 |
+
def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True,
|
1451 |
+
sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None):
|
1452 |
+
Q_LEN = Q_LEN or N_CTX
|
1453 |
+
torch.manual_seed(20)
|
1454 |
+
q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
|
1455 |
+
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
|
1456 |
+
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
|
1457 |
+
|
1458 |
+
if sm_scale is None:
|
1459 |
+
sm_scale = 1. / math.sqrt(D_HEAD)
|
1460 |
+
|
1461 |
+
# for debugging
|
1462 |
+
# print(f'>> {q.shape=}, {k.shape=}, {v.shape=}, {homo_head=}, {kernel_block_size=}, {sparse_block_size=}, {local_blocks=}, {vert_stride=}')
|
1463 |
+
sm_scale = 0.0078125
|
1464 |
+
if backward:
|
1465 |
+
q.requires_grad_(), k.requires_grad_(), v.requires_grad_()
|
1466 |
+
|
1467 |
+
# qkv = torch.empty((Z, N_CTX, 3*H*D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
|
1468 |
+
# q = qkv[..., :H*D_HEAD]
|
1469 |
+
# k = qkv[..., H*D_HEAD:2*H*D_HEAD]
|
1470 |
+
# v = qkv[..., 2*H*D_HEAD:]
|
1471 |
+
# q = q.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
|
1472 |
+
# k = k.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
|
1473 |
+
# v = v.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
|
1474 |
+
|
1475 |
+
# if Q_LEN and Q_LEN < N_CTX:
|
1476 |
+
# q = q[:, :, -Q_LEN:] # .contiguous()
|
1477 |
+
|
1478 |
+
# q = q.requires_grad_()
|
1479 |
+
# k = k.requires_grad_()
|
1480 |
+
# v = v.requires_grad_()
|
1481 |
+
|
1482 |
+
dout = torch.randn_like(q).contiguous()
|
1483 |
+
|
1484 |
+
# dout = torch.eye(N_CTX)[:, :D_HEAD][None, None].expand_as(q).type_as(q).contiguous()
|
1485 |
+
# print(dout)
|
1486 |
+
|
1487 |
+
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size,
|
1488 |
+
local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True)
|
1489 |
+
|
1490 |
+
if sparse_attention_fn is None:
|
1491 |
+
sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX,
|
1492 |
+
sparse_block_size=sparse_block_size,
|
1493 |
+
local_blocks=local_blocks,
|
1494 |
+
vert_stride=vert_stride,
|
1495 |
+
homo_head=homo_head,
|
1496 |
+
device=q.device,
|
1497 |
+
dtype=q.dtype,
|
1498 |
+
kernel_block_size=kernel_block_size)
|
1499 |
+
# reference implementation
|
1500 |
+
ref_out = torch_attention(q, k, v, mask_dense, sm_scale)
|
1501 |
+
|
1502 |
+
# lengths = torch.full((Z,), fill_value=N_CTX, device='cuda')
|
1503 |
+
# cu_seqlens = torch.zeros((Z + 1,), device='cuda', dtype=torch.int32)
|
1504 |
+
# cu_seqlens[1:] = lengths.cumsum(0)
|
1505 |
+
# # qkv = torch.randn((Z * N_CTX, 3, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1506 |
+
|
1507 |
+
# qkv_list = list(map(lambda x: x.permute(0, 2, 1, 3).contiguous().view(Z * N_CTX, 1, H, D_HEAD), [q, k, v]))
|
1508 |
+
# qkv = torch.cat(qkv_list, dim=1)
|
1509 |
+
# ref_out0 = flash_attn_func(qkv, cu_seqlens, dropout_p=0, max_s=N_CTX, softmax_scale=sm_scale, causal=True)
|
1510 |
+
# ref_out = ref_out0.view(Z, N_CTX, H, D_HEAD).permute(0, 2, 1, 3).contiguous()
|
1511 |
+
|
1512 |
+
|
1513 |
+
if backward:
|
1514 |
+
ref_out.backward(dout)
|
1515 |
+
ref_dv, v.grad = v.grad.clone(), None
|
1516 |
+
ref_dk, k.grad = k.grad.clone(), None
|
1517 |
+
ref_dq, q.grad = q.grad.clone(), None
|
1518 |
+
|
1519 |
+
tri_out = sparse_attention_fn(q, k, v, sm_scale)
|
1520 |
+
|
1521 |
+
decimal = 1 if dtype == torch.bfloat16 else 2
|
1522 |
+
assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}'
|
1523 |
+
|
1524 |
+
if backward:
|
1525 |
+
tri_out.backward(dout)
|
1526 |
+
tri_dv, v.grad = v.grad.clone(), None
|
1527 |
+
tri_dk, k.grad = k.grad.clone(), None
|
1528 |
+
tri_dq, q.grad = q.grad.clone(), None
|
1529 |
+
|
1530 |
+
if backward:
|
1531 |
+
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
|
1532 |
+
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
|
1533 |
+
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
|
1534 |
+
|
1535 |
+
print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}')
|
1536 |
+
|
1537 |
+
###########################################################
|
1538 |
+
|
1539 |
+
if __name__ == '__main__':
|
1540 |
+
|
1541 |
+
GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip()
|
1542 |
+
# print(GPU_TYPE)
|
1543 |
+
support_backward = True # 'A100' in GPU_TYPE. Wasn't supportted in consumer A1000.
|
1544 |
+
|
1545 |
+
###############
|
1546 |
+
# benchmarking
|
1547 |
+
|
1548 |
+
HAS_DENSE_TRITON_FLASH = False
|
1549 |
+
# try:
|
1550 |
+
# from triton.ops.flash_attention import attention as triton_attention
|
1551 |
+
# HAS_DENSE_TRITON_FLASH = True
|
1552 |
+
# except:
|
1553 |
+
# HAS_DENSE_TRITON_FLASH = False
|
1554 |
+
# print('> cannot import Trition flash attn')
|
1555 |
+
|
1556 |
+
try:
|
1557 |
+
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func
|
1558 |
+
HAS_FLASH = True
|
1559 |
+
except BaseException:
|
1560 |
+
HAS_FLASH = False
|
1561 |
+
print('> cannot import flash_attn')
|
1562 |
+
|
1563 |
+
|
1564 |
+
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
1565 |
+
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128 # 6.7B model, with 4k len
|
1566 |
+
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 16, 4096, 128 # 204m model
|
1567 |
+
|
1568 |
+
BLOCK_SIZE = 64
|
1569 |
+
LOCAl_BLOCKS = 8 # 4
|
1570 |
+
VERT_STRIDE = 1 # 16 # 8
|
1571 |
+
HOMO_HEAD = False
|
1572 |
+
sparse_type = 'home' if HOMO_HEAD else 'hetero'
|
1573 |
+
dtype = torch.bfloat16
|
1574 |
+
|
1575 |
+
|
1576 |
+
modes = ['fwd', 'bwd'] if support_backward else ['fwd']
|
1577 |
+
|
1578 |
+
configs = [triton.testing.Benchmark(
|
1579 |
+
x_names=['SEQ_LEN'],
|
1580 |
+
x_vals=[2**i for i in range(8, 16)],
|
1581 |
+
line_arg='provider',
|
1582 |
+
line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'],
|
1583 |
+
line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else []) + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'],
|
1584 |
+
styles=[('red', '-'), ('blue', '-'), ('green', '-')],
|
1585 |
+
ylabel='ms',
|
1586 |
+
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}',
|
1587 |
+
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode}
|
1588 |
+
) for mode in modes]
|
1589 |
+
|
1590 |
+
|
1591 |
+
@triton.testing.perf_report(configs)
|
1592 |
+
def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
|
1593 |
+
assert mode in ['fwd', 'bwd']
|
1594 |
+
warmup = 25
|
1595 |
+
rep = 100
|
1596 |
+
N_CTX = SEQ_LEN
|
1597 |
+
if provider == 'triton':
|
1598 |
+
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1599 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1600 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1601 |
+
sm_scale = 1.3
|
1602 |
+
fn = lambda: triton_attention(q, k, v, sm_scale)
|
1603 |
+
if mode == 'bwd':
|
1604 |
+
o = fn()
|
1605 |
+
do = torch.randn_like(o)
|
1606 |
+
fn = lambda: o.backward(do, retain_graph=True)
|
1607 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1608 |
+
return ms
|
1609 |
+
if provider == 'triton_sparse':
|
1610 |
+
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1611 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1612 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1613 |
+
sm_scale = 1.3
|
1614 |
+
# q_pos = torch.arange(N_CTX // BLOCK, device='cuda')[:, None]
|
1615 |
+
# k_pos = torch.arange(N_CTX // BLOCK, device='cuda')[None]
|
1616 |
+
# local_blocks = 4 # num_block per attn, block_size is tied to BLOCK
|
1617 |
+
# vert_stride =N_CTX + 1 # 4
|
1618 |
+
# mask_vert_strided = torch.arange(N_CTX // BLOCK, device='cuda') % vert_stride == vert_stride - 1
|
1619 |
+
# mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).type_as(q)
|
1620 |
+
# mask = mask_dense.to_sparse_csr()
|
1621 |
+
# mask_csr, _ = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK, local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD)
|
1622 |
+
|
1623 |
+
if sparse_attention_fn is None:
|
1624 |
+
# sparse_attention_fn = sparse_attention
|
1625 |
+
sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN,
|
1626 |
+
local_blocks=LOCAl_BLOCKS,
|
1627 |
+
vert_stride=VERT_STRIDE,
|
1628 |
+
homo_head=HOMO_HEAD,
|
1629 |
+
sparse_block_size=BLOCK_SIZE,
|
1630 |
+
kernel_block_size=BLOCK_SIZE,
|
1631 |
+
device=q.device)
|
1632 |
+
# sparse_attention_fn = sparse_attention_factory(128, 128, num_warps=8)
|
1633 |
+
|
1634 |
+
# fn = lambda: sparse_attention_fn(q, k, v, mask_csr[0], mask_csr[1], sm_scale)
|
1635 |
+
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
|
1636 |
+
if mode == 'bwd':
|
1637 |
+
o = fn()
|
1638 |
+
do = torch.randn_like(o)
|
1639 |
+
fn = lambda: o.backward(do, retain_graph=True)
|
1640 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1641 |
+
return ms
|
1642 |
+
if provider == 'flash':
|
1643 |
+
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
1644 |
+
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
1645 |
+
cu_seqlens[1:] = lengths.cumsum(0)
|
1646 |
+
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
1647 |
+
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
1648 |
+
if mode == 'bwd':
|
1649 |
+
o = fn()
|
1650 |
+
do = torch.randn_like(o)
|
1651 |
+
fn = lambda: o.backward(do, retain_graph=True)
|
1652 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1653 |
+
return ms
|
1654 |
+
|
1655 |
+
# if provider == 'torch':
|
1656 |
+
# q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1657 |
+
# k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1658 |
+
# v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
|
1659 |
+
# sm_scale = 1.3
|
1660 |
+
# causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(q)
|
1661 |
+
# fn = lambda: torch_attention(q, k, v, causal_mask, sm_scale)
|
1662 |
+
# ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
1663 |
+
# return ms
|
1664 |
+
|
1665 |
+
|
1666 |
+
BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1 # 6.7B model, with 4k len
|
1667 |
+
|
1668 |
+
BLOCK_SIZE = 64
|
1669 |
+
LOCAl_BLOCKS = 8 # 4
|
1670 |
+
VERT_STRIDE = 16 # 8
|
1671 |
+
HOMO_HEAD = False
|
1672 |
+
sparse_type = 'home' if HOMO_HEAD else 'hetero'
|
1673 |
+
dtype = torch.bfloat16
|
1674 |
+
MAX_N_CTX = 8192
|
1675 |
+
|
1676 |
+
configs = [triton.testing.Benchmark(
|
1677 |
+
x_names=['PAST_LEN'],
|
1678 |
+
x_vals=[2**i - 1 for i in range(8, 14)],
|
1679 |
+
line_arg='provider',
|
1680 |
+
line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'],
|
1681 |
+
line_names=['Torch'] + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'],
|
1682 |
+
styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')],
|
1683 |
+
ylabel='ms',
|
1684 |
+
plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}',
|
1685 |
+
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode}
|
1686 |
+
) for mode in ['fwd']]
|
1687 |
+
@triton.testing.perf_report(configs)
|
1688 |
+
def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
|
1689 |
+
assert mode in ['fwd']
|
1690 |
+
warmup = 25
|
1691 |
+
rep = 100
|
1692 |
+
N_CTX = PAST_LEN + Q_LEN
|
1693 |
+
if provider == 'torch':
|
1694 |
+
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1695 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1696 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1697 |
+
sm_scale = 1.3
|
1698 |
+
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE,
|
1699 |
+
local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=VERT_STRIDE, return_dense=True)
|
1700 |
+
|
1701 |
+
fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048)
|
1702 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1703 |
+
return ms
|
1704 |
+
if provider == 'triton_sparse':
|
1705 |
+
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1706 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1707 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1708 |
+
sm_scale = 1.3
|
1709 |
+
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
|
1710 |
+
local_blocks=LOCAl_BLOCKS,
|
1711 |
+
vert_stride=VERT_STRIDE,
|
1712 |
+
homo_head=HOMO_HEAD,
|
1713 |
+
sparse_block_size=BLOCK_SIZE,
|
1714 |
+
kernel_block_size=BLOCK_SIZE,
|
1715 |
+
device=q.device,
|
1716 |
+
inference=True)
|
1717 |
+
|
1718 |
+
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
|
1719 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1720 |
+
return ms
|
1721 |
+
if provider == 'triton_dense':
|
1722 |
+
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1723 |
+
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1724 |
+
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1725 |
+
sm_scale = 1.3
|
1726 |
+
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
|
1727 |
+
local_blocks=1,
|
1728 |
+
vert_stride=1,
|
1729 |
+
homo_head=True,
|
1730 |
+
sparse_block_size=BLOCK_SIZE,
|
1731 |
+
kernel_block_size=BLOCK_SIZE,
|
1732 |
+
device=q.device,
|
1733 |
+
inference=True)
|
1734 |
+
|
1735 |
+
fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
|
1736 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1737 |
+
return ms
|
1738 |
+
if provider == 'flash':
|
1739 |
+
assert Q_LEN == 1
|
1740 |
+
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
1741 |
+
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
1742 |
+
cu_seqlens[1:] = lengths.cumsum(0)
|
1743 |
+
cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32)
|
1744 |
+
|
1745 |
+
# (total_q, nheads, headdim),
|
1746 |
+
q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1747 |
+
k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1748 |
+
v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
|
1749 |
+
|
1750 |
+
fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False)
|
1751 |
+
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
1752 |
+
return ms
|
1753 |
+
|
1754 |
+
|
1755 |
+
test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
|
1756 |
+
# bench_flash_attention.run(save_path='.', print_data=True)
|
1757 |
+
|
1758 |
+
bench_flash_attention_inference.run(save_path='.', print_data=True)
|
1759 |
+
exit()
|
1760 |
+
# head_dim=64
|
1761 |
+
test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64,
|
1762 |
+
dtype=torch.bfloat16, homo_head=False, backward=support_backward)
|
1763 |
+
# uneven length, bf16
|
1764 |
+
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128,
|
1765 |
+
kernel_block_size=64, local_blocks=8, vert_stride=8)
|
1766 |
+
test_op(3, 2, 2047, 128, homo_head=False, backward=False)
|
1767 |
+
|
1768 |
+
# diff kernel/sparse block size
|
1769 |
+
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64)
|
1770 |
+
# inference
|
1771 |
+
# test_op(1, 4, 512 + 256, 128, Q_LEN=1, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
|
1772 |
+
|
1773 |
+
# dense flash attn
|
1774 |
+
test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False,
|
1775 |
+
backward=support_backward, local_blocks=1, vert_stride=1)
|
1776 |
+
|
1777 |
+
# fp16
|
1778 |
+
test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
|
1779 |
+
|
1780 |
+
# longer sequence
|
1781 |
+
test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward)
|
1782 |
+
test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
|
1783 |
+
|
1784 |
+
# homo head
|
1785 |
+
test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False)
|
1786 |
+
test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward)
|
1787 |
+
|
1788 |
+
# sparse_attention_fn = sparse_attention_factory(16, 128, num_warps=1, INFERENCE=True)
|
1789 |
+
# test_op(8, 1, 2047, 128, 1, backward=False, sparse_attention_fn=None)
|
1790 |
+
# test_op_inference(3, 2, 2048, 128, 2048)
|
1791 |
+
# test_op_inference(3, 2, 2047, 64, 2047)
|
1792 |
+
# test_op_inference(3, 2, 256, 64, 128)
|
1793 |
+
# test_op_inference(3, 2, 2048, 64, 1)
|
1794 |
+
|
1795 |
+
bench_flash_attention.run(save_path='.', print_data=True)
|
1796 |
+
# bench_flash_attention_inference.run(save_path='.', print_data=True)
|
1797 |
+
|
1798 |
+
# ========================
|
1799 |
+
# Some Benchmark Results #
|
1800 |
+
# ========================
|
1801 |
+
|
1802 |
+
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-fwd
|
1803 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1804 |
+
# 0 256.0 0.057184 0.069646 0.052567
|
1805 |
+
# 1 512.0 0.131688 0.187658 0.110212
|
1806 |
+
# 2 1024.0 0.391844 0.524990 0.247875
|
1807 |
+
# 3 2048.0 1.305190 1.456685 0.596506
|
1808 |
+
# 4 4096.0 4.623019 4.968653 1.600277
|
1809 |
+
# 5 8192.0 17.513062 18.332262 4.802458
|
1810 |
+
# 6 16384.0 68.453377 70.337540 16.052908
|
1811 |
+
# 7 32768.0 270.655487 276.020233 57.938946
|
1812 |
+
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-bwd (num_warp=8):
|
1813 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1814 |
+
# 0 256.0 0.190120 0.150313 0.181451
|
1815 |
+
# 1 512.0 0.406348 0.391767 0.391177
|
1816 |
+
# 2 1024.0 1.029704 1.182967 0.885741
|
1817 |
+
# 3 2048.0 2.985456 3.843399 2.040469
|
1818 |
+
# 4 4096.0 9.808897 13.073701 5.069609
|
1819 |
+
# 5 8192.0 34.995201 47.863808 13.948782
|
1820 |
+
# 6 16384.0 132.740097 182.579193 42.816513
|
1821 |
+
# 7 32768.0 542.223389 714.820618 147.053574
|
1822 |
+
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
|
1823 |
+
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
|
1824 |
+
# 0 256.0 0.050949 0.032357 0.107513
|
1825 |
+
# 1 512.0 0.073624 0.050651 0.199086
|
1826 |
+
# 2 1024.0 0.107472 0.080379 0.245445
|
1827 |
+
# 3 2048.0 0.178423 0.129448 0.338259
|
1828 |
+
# 4 4096.0 0.327647 0.223106 0.517048
|
1829 |
+
# 5 8192.0 0.588423 0.411263 0.884606
|
1830 |
+
# 6 16384.0 1.098898 0.798941 1.611809
|
1831 |
+
# 7 32768.0 2.094537 1.594726 3.044160
|
1832 |
+
|
1833 |
+
|
1834 |
+
# 6.7B
|
1835 |
+
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-fwd:
|
1836 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1837 |
+
# 0 256.0 0.069208 0.082156 0.065097
|
1838 |
+
# 1 512.0 0.138271 0.201393 0.144467
|
1839 |
+
# 2 1024.0 0.391521 0.624614 0.322382
|
1840 |
+
# 3 2048.0 1.268443 2.406325 0.784367
|
1841 |
+
# 4 4096.0 4.455703 9.139097 2.100856
|
1842 |
+
# 5 8192.0 16.764315 35.289600 6.328320
|
1843 |
+
# 6 16384.0 65.221634 138.401794 21.069057
|
1844 |
+
# 7 32768.0 257.251343 548.085754 76.111870
|
1845 |
+
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-bwd:
|
1846 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1847 |
+
# 0 256.0 0.297118 0.266469 0.255255
|
1848 |
+
# 1 512.0 0.672826 0.613685 0.552954
|
1849 |
+
# 2 1024.0 1.718434 1.705066 1.251953
|
1850 |
+
# 3 2048.0 4.936755 5.403875 2.927895
|
1851 |
+
# 4 4096.0 15.911594 18.959362 7.436288
|
1852 |
+
# 5 8192.0 55.357441 70.808578 21.140224
|
1853 |
+
# 6 16384.0 208.188416 273.617920 68.018173
|
1854 |
+
# 7 32768.0 806.037476 1081.453613 218.720261
|
1855 |
+
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
|
1856 |
+
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
|
1857 |
+
# 0 256.0 0.050151 0.032337 0.107593
|
1858 |
+
# 1 512.0 0.073409 0.051737 0.200200
|
1859 |
+
# 2 1024.0 0.107533 0.082099 0.247067
|
1860 |
+
# 3 2048.0 0.177259 0.128891 0.338510
|
1861 |
+
# 4 4096.0 0.325866 0.223621 0.524842
|
1862 |
+
# 5 8192.0 0.586926 0.408913 0.885490
|
1863 |
+
# 6 16384.0 1.100834 0.793277 1.612271
|
1864 |
+
# 7 32768.0 2.098851 1.595831 3.064544
|
1865 |
+
|
1866 |
+
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-fwd:
|
1867 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1868 |
+
# 0 256.0 0.066673 0.082037 0.065085
|
1869 |
+
# 1 512.0 0.137379 0.201880 0.143473
|
1870 |
+
# 2 1024.0 0.390675 0.624234 0.312046
|
1871 |
+
# 3 2048.0 1.267739 2.406950 0.696045
|
1872 |
+
# 4 4096.0 4.445138 9.136333 1.665788
|
1873 |
+
# 5 8192.0 16.768614 35.265533 4.380486
|
1874 |
+
# 6 16384.0 65.235970 138.393600 12.997633
|
1875 |
+
# 7 32768.0 257.317902 550.442993 42.821121
|
1876 |
+
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-bwd:
|
1877 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1878 |
+
# 0 256.0 0.296461 0.266581 0.254022
|
1879 |
+
# 1 512.0 0.671427 0.613643 0.551283
|
1880 |
+
# 2 1024.0 1.719918 1.704295 1.229982
|
1881 |
+
# 3 2048.0 4.945305 5.403364 2.721906
|
1882 |
+
# 4 4096.0 15.934293 18.960999 6.259371
|
1883 |
+
# 5 8192.0 55.406593 70.832130 15.676929
|
1884 |
+
# 6 16384.0 208.750595 275.004425 44.837891
|
1885 |
+
# 7 32768.0 808.057861 1080.647705 141.856766
|
1886 |
+
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert8-hetero:
|
1887 |
+
# PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
|
1888 |
+
# 0 256.0 0.050739 0.032886 0.107837
|
1889 |
+
# 1 512.0 0.073507 0.051996 0.200293
|
1890 |
+
# 2 1024.0 0.106394 0.080679 0.240610
|
1891 |
+
# 3 2048.0 0.177659 0.127660 0.287625
|
1892 |
+
# 4 4096.0 0.326326 0.226971 0.377500
|
1893 |
+
# 5 8192.0 0.586339 0.407367 0.559266
|
1894 |
+
# 6 16384.0 1.102279 0.786221 0.920976
|
1895 |
+
# 7 32768.0 2.097370 1.545090 1.644288
|
1896 |
+
|
1897 |
+
|
1898 |
+
################
|
1899 |
+
##### fp16 #####
|
1900 |
+
################
|
1901 |
+
|
1902 |
+
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
|
1903 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1904 |
+
# 0 256.0 0.032518 0.035472 0.029939
|
1905 |
+
# 1 512.0 0.054266 0.087841 0.054320
|
1906 |
+
# 2 1024.0 0.133447 0.263090 0.102045
|
1907 |
+
# 3 2048.0 0.384615 1.023293 0.201763
|
1908 |
+
# 4 4096.0 1.300890 4.023936 0.449555
|
1909 |
+
# 5 8192.0 4.774144 15.816704 1.150854
|
1910 |
+
# 6 16384.0 18.220032 62.771198 3.356001
|
1911 |
+
# 7 32768.0 71.405571 250.273788 10.976142
|
1912 |
+
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
|
1913 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1914 |
+
# 0 256.0 0.083342 0.069742 0.079496
|
1915 |
+
# 1 512.0 0.159894 0.170995 0.151705
|
1916 |
+
# 2 1024.0 0.386071 0.522407 0.331443
|
1917 |
+
# 3 2048.0 1.067715 1.737333 0.715248
|
1918 |
+
# 4 4096.0 3.382731 6.219520 1.597457
|
1919 |
+
# 5 8192.0 11.857793 23.560448 3.879035
|
1920 |
+
# 6 16384.0 44.422142 91.251709 10.626843
|
1921 |
+
# 7 32768.0 175.011841 359.473145 32.340992
|
1922 |
+
|
1923 |
+
|
1924 |
+
################
|
1925 |
+
##### bf16 #####
|
1926 |
+
################
|
1927 |
+
|
1928 |
+
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
|
1929 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1930 |
+
# 0 256.0 0.037636 0.035902 0.031512
|
1931 |
+
# 1 512.0 0.058591 0.087229 0.058125
|
1932 |
+
# 2 1024.0 0.143337 0.263919 0.108443
|
1933 |
+
# 3 2048.0 0.414458 1.025985 0.214114
|
1934 |
+
# 4 4096.0 1.390841 4.020010 0.480550
|
1935 |
+
# 5 8192.0 5.067938 15.808171 1.230874
|
1936 |
+
# 6 16384.0 19.442280 62.765057 3.597274
|
1937 |
+
# 7 32768.0 75.501572 250.443771 11.768959
|
1938 |
+
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
|
1939 |
+
# SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
|
1940 |
+
# 0 256.0 0.084404 0.070663 0.082613
|
1941 |
+
# 1 512.0 0.161510 0.172882 0.157661
|
1942 |
+
# 2 1024.0 0.388954 0.526047 0.339855
|
1943 |
+
# 3 2048.0 1.075814 1.736057 0.732420
|
1944 |
+
# 4 4096.0 3.401622 6.221376 1.636039
|
1945 |
+
# 5 8192.0 11.915136 23.483391 3.968725
|
1946 |
+
# 6 16384.0 44.660225 91.302910 10.857130
|
1947 |
+
# 7 32768.0 175.038467 359.048187 32.778240
|