File size: 6,858 Bytes
c4bfc74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from functools import lru_cache
from transformers.models.llama.modeling_llama import logger
import os
# Options handed to every torch.compile call in this module.
# The first three flags are simply switched on; setting the environment
# variable UNSLOTH_COMPILE_DEBUG=1 enables "trace.enabled", and
# "triton.cudagraphs" is always left off.
torch_compile_options = {
    **dict.fromkeys(
        ("epilogue_fusion", "max_autotune", "shape_padding"),
        True,
    ),
    "trace.enabled"     : os.getenv("UNSLOTH_COMPILE_DEBUG", "0") == "1",
    "triton.cudagraphs" : False,
}
# Flex Attention supported from torch 2.5 onwards only.
# Try to import it and wrap it in torch.compile; `_create_block_mask` is the
# raw torch helper wrapped further down in this module.
try:
    from torch.nn.attention.flex_attention import (
        flex_attention as _flex_attention,
        create_block_mask as _create_block_mask,
    )
    _flex_attention = torch.compile(
        _flex_attention, dynamic = True, options = torch_compile_options,
    )
    # NOTE(review): the flag is False even when the import succeeds, so the
    # flex-attention code path is currently disabled on every torch version.
    # Presumably deliberate -- confirm before switching this to True.
    HAS_FLEX_ATTENTION = False
except Exception:
    # Older torch (module missing) or torch.compile refusing to wrap.
    # `except Exception` instead of a bare `except:` so KeyboardInterrupt and
    # SystemExit raised during import are not silently swallowed.
    HAS_FLEX_ATTENTION = False
pass
if not HAS_FLEX_ATTENTION:
    # Fallback path: plain softcapped attention, compiled as a single graph.

    @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
    def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
        """Attention with logit softcapping and grouped-query KV expansion.

        Args:
            Q: queries, expected (bsz, n_heads, q_len, head_dim).
            K, V: keys / values, expected (bsz, n_kv_heads, q_len, head_dim).
                Each KV head is repeated ``num_key_value_groups`` times so the
                head count matches ``num_heads``.
            causal_mask: additive mask; its ``[:q_len, :q_len]`` slice is
                added to the softcapped scores before the softmax.
            self: attention module providing ``num_heads``, ``head_dim``,
                ``num_key_value_heads``, ``num_key_value_groups`` and a
                ``config`` with ``query_pre_attn_scalar`` and
                ``attn_logit_softcapping``.
            bsz, q_len: batch size and sequence length.

        Returns:
            Tensor reshaped to (bsz, q_len, n_heads * head_dim). The softmax
            runs in float32 and is cast back to Q's dtype.
        """
        num_heads = self.num_heads
        dim       = self.head_dim
        kv_heads  = self.num_key_value_heads
        groups    = self.num_key_value_groups

        # Grouped query attention: insert a group axis, broadcast it, then
        # fold it into the head axis.
        grouped_shape = (bsz, kv_heads, groups, q_len, dim)
        flat_shape    = (bsz, num_heads, q_len, dim)
        K = K[:, :, None, :, :].expand(grouped_shape).reshape(flat_shape)
        V = V[:, :, None, :, :].expand(grouped_shape).reshape(flat_shape)

        # Query scale comes from config.query_pre_attn_scalar rather than
        # hidden_size // num_attention_heads (Gemma 9b should use 256 and not
        # 224); see
        # https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
        scalar = self.config.query_pre_attn_scalar
        cap    = self.config.attn_logit_softcapping
        Q = Q * torch.tensor(scalar**-0.5, dtype = Q.dtype) # Follow Keras exactly

        scores = torch.matmul(Q, K.transpose(2, 3))
        scores = cap * torch.tanh(scores / cap) # Logit softcapping
        # Adding the mask is used instead of masked_fill_, which was much
        # slower under torch.compile.
        scores += causal_mask[:q_len, :q_len]
        probs = torch.nn.functional.softmax(scores, dim = -1, dtype = torch.float32).to(Q.dtype)
        out = torch.matmul(probs, V)
        return out.transpose(1, 2).contiguous().reshape(bsz, q_len, num_heads*dim)
    pass

    # Block-mask builders only exist on the flex-attention path.
    create_flex_attention_causal_mask = None
    create_flex_attention_sliding_window_mask = None
else:
    # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
    # for more examples
    # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
    import functools, math

    def generate_tanh_softcap(t):
        """Build a flex-attention score_mod returning t * tanh(score / t)."""
        def tanh_softcap(score, b, h, q_idx, kv_idx):
            return t * torch.tanh(score / t)
        return tanh_softcap
    pass

    def causal_masker(b, h, q_idx, kv_idx):
        """mask_mod: keep keys at or before the query position."""
        return kv_idx <= q_idx
    pass

    @functools.lru_cache
    def sliding_window_masker(size = 4096):
        """Build (and cache per size) a causal mask_mod limited to a window.

        A key is kept when it is not after the query and at most ``size``
        positions behind it.
        """
        def sliding_window(b, h, q_idx, kv_idx):
            is_causal = kv_idx <= q_idx
            in_window = q_idx - kv_idx <= size
            return is_causal & in_window
        return sliding_window
    pass

    @functools.lru_cache
    def create_block_mask(mask, n = 128):
        """Compile a BlockMask for batch 1, 1 head and an n x n score matrix.

        Cached per (mask function, n), so repeated requests reuse one mask.
        """
        batch, heads = 1, 1
        return _create_block_mask(
            mask, batch, heads, n, n,
            BLOCK_SIZE = 128,
            _compile = True,
        )
    pass

    def create_flex_attention_causal_mask(max_seq_length = 8192):
        """Return the cached causal BlockMask for ``max_seq_length`` tokens."""
        return create_block_mask(causal_masker, max_seq_length)
    pass

    def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
        """Return the cached sliding-window BlockMask for ``max_seq_length`` tokens."""
        windowed = sliding_window_masker(sliding_window)
        return create_block_mask(windowed, max_seq_length)
    pass

    @functools.lru_cache
    def flex_attention(s, t):
        """Return flex attention pre-bound to 1/sqrt(s) scaling and t-softcapping.

        Cached per (s, t); ``enable_gqa=True`` lets K/V carry fewer heads
        than Q.
        """
        scale = 1.0 / math.sqrt(s)
        return functools.partial(
            _flex_attention,
            score_mod  = generate_tanh_softcap(t),
            scale      = scale,
            enable_gqa = True,
        )
    pass

    def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
        """Softcapped attention computed by flex attention.

        ``causal_mask`` is forwarded as flex attention's ``block_mask``
        (presumably one built by the create_flex_attention_* helpers here).
        The output is reshaped to (bsz, q_len, n_heads * head_dim).
        """
        out_features = self.num_heads * self.head_dim
        attend = flex_attention(
            self.config.query_pre_attn_scalar,
            self.config.attn_logit_softcapping,
        )
        out = attend(query = Q, key = K, value = V, block_mask = causal_mask)
        return out.transpose(1, 2).contiguous().reshape(bsz, q_len, out_features)
    pass
pass
# Module-level aliases for the torch callables used on the inference path.
torch_matmul, torch_tanh, torch_nn_functional_softmax = (
    torch.matmul, torch.tanh, torch.nn.functional.softmax,
)

def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
    """Eager (uncompiled) softcapped attention, softcapping done in place.

    Args:
        Q: queries, expected (bsz, n_heads, q_len, head_dim).
        K, V: keys / values, expected (bsz, n_kv_heads, q_len, head_dim);
            every KV head is repeated ``num_key_value_groups`` times.
        causal_mask: additive mask; its ``[:q_len, :q_len]`` slice is added
            to the softcapped scores.
        self: attention module providing ``num_heads``, ``head_dim``,
            ``num_key_value_heads``, ``num_key_value_groups`` and a
            ``config`` with ``query_pre_attn_scalar`` and
            ``attn_logit_softcapping``.
        bsz, q_len: batch size and sequence length.

    Returns:
        Tensor of shape (bsz, q_len, n_heads * head_dim); the softmax runs in
        float32 and is cast back to Q's dtype.
    """
    num_heads = self.num_heads
    dim       = self.head_dim
    kv_heads  = self.num_key_value_heads
    groups    = self.num_key_value_groups

    # Grouped query attention: broadcast a new group axis, fold into heads.
    grouped_shape = (bsz, kv_heads, groups, q_len, dim)
    flat_shape    = (bsz, num_heads, q_len, dim)
    K = K[:, :, None, :, :].expand(grouped_shape).reshape(flat_shape)
    V = V[:, :, None, :, :].expand(grouped_shape).reshape(flat_shape)

    # Query scale comes from config.query_pre_attn_scalar rather than
    # hidden_size // num_attention_heads (Gemma 9b should use 256 and not
    # 224); see
    # https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
    scalar = self.config.query_pre_attn_scalar
    cap    = self.config.attn_logit_softcapping
    Q = Q * torch.tensor(scalar**-0.5, dtype = Q.dtype) # Follow Keras exactly

    scores = torch_matmul(Q, K.transpose(2, 3))
    # Logit softcapping, cap * tanh(scores / cap), without extra buffers.
    scores /= cap
    torch_tanh(scores, out = scores)
    scores *= cap
    # Adding the mask is used instead of masked_fill_, which was much slower
    # under torch.compile.
    scores += causal_mask[:q_len, :q_len]

    probs = torch_nn_functional_softmax(scores, dim = -1, dtype = torch.float32).to(Q.dtype)
    out = torch_matmul(probs, V)
    return out.transpose(1, 2).contiguous().reshape(bsz, q_len, num_heads*dim)
pass
|