File size: 28,328 Bytes
c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 |
# Adapted from https://github.com/mosaicml/llm-foundry
# Classes changed: MultiheadAttention
# Functions changed: scaled_multihead_dot_product_attention, build_alibi_bias, build_attn_bias
# SPDX-License-Identifier: Apache-2.0
"""Attention layers."""
import math
import warnings
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from packaging import version
from torch import nn
from torch.linalg import vector_norm
from llmfoundry.models.layers.norm import LPLayerNorm
from torch.nn import functional as F
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
original_is_causal: bool):
# disable causal when it is not needed
# necessary for flash & triton for generation with kv_cache
if original_is_causal and num_query_tokens != num_key_tokens:
if num_query_tokens != 1:
raise NotImplementedError(
'MPT does not support query and key with different number of tokens, unless number of query tokens is 1.'
)
else:
return False
return original_is_causal
def scaled_multihead_dot_product_attention(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    long_range_past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    attn_bias_ae=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
    topk=None,
    faiss_indexes=None,
    n_layers=None,
    current_layer=None,
    mask_by_sim=False,
    sim_threshold=0.0
):
    """Torch attention over local keys plus optionally retrieved memories.

    Local attention scores are computed from ``query``/``key``. When
    ``long_range_past_key_value`` or ``faiss_indexes`` is given, ``topk``
    memory key/value pairs are retrieved per query token and their scores are
    prepended to the local scores before the softmax.

    Args:
        query: Tensor laid out as ``(b, s, h * d)``.
        key: Tensor laid out as ``(b, s, h * d)`` (``h == 1`` if multiquery).
        value: Tensor laid out as ``(b, s, h * d)`` (``h == 1`` if multiquery).
        n_heads: Number of query heads.
        past_key_value: Optional kv cache. Keys are stored as
            ``[b, h, d_head, s]`` and values as ``[b, h, s, d_head]``; an
            empty sequence starts a new cache.
        long_range_past_key_value: Optional ``(k_cache, v_cache)`` memories,
            searched by cosine similarity against the queries.
        softmax_scale: Score scale; defaults to ``1 / sqrt(d_head)``.
        attn_bias: Optional additive bias for the local scores; must
            broadcast to ``(b, h, s_q, s_k)``.
        attn_bias_ae: Optional additive bias for the memory scores.
        key_padding_mask: Optional boolean mask, viewed as ``(b, 1, 1, s_k)``.
        is_causal: Apply a causal mask (skipped for a single query token).
        dropout_p: Dropout probability applied to the attention weights.
        training: Whether dropout is active.
        needs_weights: Return attention weights and memory indices.
        multiquery: Keys/values have a single shared head.
        topk: Number of memories retrieved per query token.
        faiss_indexes: Optional ``(kn_index, kv_index)`` pair; ``kn_index``
            must provide ``search`` and ``kv_index`` ``reconstruct_batch``.
            NOTE(review): this path appears to assume batch size 1 (the
            query is squeezed and the result is given a leading dim of 1) —
            confirm against callers.
        n_layers: Total number of layers (used to tag FAISS queries).
        current_layer: Index of this layer (used to tag FAISS queries).
        mask_by_sim: Mask memories whose similarity is not above
            ``sim_threshold``. Only available for ``long_range_past_key_value``.
        sim_threshold: Similarity threshold used by ``mask_by_sim``.

    Returns:
        ``(out, attn_weight, past_key_value, reshaped_idx)`` where ``out`` is
        ``(b, s, h * d)``. ``attn_weight`` and ``reshaped_idx`` are ``None``
        unless ``needs_weights`` is set; ``reshaped_idx`` is also ``None``
        when no manual memories were searched.

    Raises:
        RuntimeError: If a bias cannot broadcast to the score shape.
        NotImplementedError: If ``mask_by_sim`` is requested together with
            FAISS memories (no similarity mask is available on that path).
    """
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)

    had_kv = False
    if past_key_value is not None:
        # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
        # kv_cache is therefore stored using that shape.
        # attn_impl: torch stores the kv_cache in the ordering which is most advantageous
        # for its attn computation ie
        # keys are stored as tensors with shape [b, h, d_head, s] and
        # values are stored as tensors with shape [b, h, s, d_head]
        if len(past_key_value) != 0:
            k = torch.cat([past_key_value[0], k], dim=3)
            v = torch.cat([past_key_value[1], v], dim=2)
            had_kv = True
        past_key_value = (k, v)

    b, h, s_q, d = q.shape
    s_k = k.size(-1)

    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)

    attn_weight = q.matmul(k) * softmax_scale

    if attn_bias is not None:
        # clamp to 0 necessary for torch 2.0 compile()
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if (attn_bias.size(-1) != 1 and
                attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
                                               attn_bias.size(-2) != s_q):
            raise RuntimeError(
                f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
            )
        attn_weight = attn_weight + attn_bias

    # Memory indices are only produced by the manual-memory search below and
    # are only handed back to the caller when needs_weights is set.
    reshaped_idx = None

    if long_range_past_key_value is not None or faiss_indexes is not None:
        sim_mask = None
        memory_min_val = None
        if long_range_past_key_value is not None:  # manual memories
            k_cache, v_cache = long_range_past_key_value
            s_cache = k_cache.size(-1)
            k_cache = k_cache.to(k.device)
            v_cache = v_cache.to(k.device)

            # Cosine similarity between every query and every cached key.
            q_n = q / vector_norm(q, ord=2, dim=-1, keepdim=True)
            k_n = k_cache / vector_norm(k_cache, ord=2, dim=-2, keepdim=True)
            sim = q_n.matmul(k_n)
            if s_cache < topk:
                topk = s_cache  # number of tokens in cache < topk
            val, idx = torch.topk(sim, k=topk, dim=-1)

            reshaped_idx = idx.reshape(b, h, s_q * topk)
            selected_k = k_cache.gather(
                dim=-1,
                index=reshaped_idx.unsqueeze(-2).expand(-1, -1, d, -1))
            selected_v = v_cache.gather(
                dim=-2,
                index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, d))

            sim_mask = rearrange(
                ~(val > sim_threshold).bool(),
                'b h s i -> b h (s i)').unsqueeze(-2).expand(-1, -1, s_q, -1)
            memory_min_val = torch.finfo(selected_k.dtype).min
        else:  # faiss indexes
            if mask_by_sim:
                # Previously this surfaced as an UnboundLocalError further
                # down because no similarity mask exists on this path.
                raise NotImplementedError(
                    'mask_by_sim is not supported when memories are retrieved from faiss_indexes.'
                )
            kn_index, kv_index = faiss_indexes
            q_n = q / vector_norm(q, ord=2, dim=-1, keepdim=True)

            # Each (layer, head) pair is tagged with a scaled one-hot suffix
            # appended to the normalized query before searching the index.
            one_hot_encodings = F.one_hot(
                torch.arange(0, n_heads * n_layers, device=q.device)) * 10
            flat_q = rearrange(q_n, 'b h s d -> b (h s) d', h=n_heads)
            layer_tags = one_hot_encodings[n_heads * current_layer:n_heads *
                                           (current_layer + 1)]
            layer_tags = layer_tags.unsqueeze(0).repeat_interleave(
                repeats=q.size(-2), dim=-2)
            q_n = torch.concat([flat_q, layer_tags], dim=-1).squeeze()

            _, neighbor_ids = kn_index.search(q_n.to('cpu').numpy(), k=topk)

            # Reconstruct once; the first d columns are keys, the rest values.
            retrieved = torch.tensor(
                kv_index.reconstruct_batch(neighbor_ids.flatten()))
            # Regroup by the actual head count (was hard-coded to 32).
            selected_k = rearrange(retrieved[:, :d],
                                   '(h s) d -> 1 h d s',
                                   h=n_heads).to(q.device)
            selected_v = rearrange(retrieved[:, d:],
                                   '(h s) d -> 1 h s d',
                                   h=n_heads).to(q.device)

        s_k_ae = selected_k.size(-1)
        s_k += s_k_ae
        attn_weight_cache = q.matmul(selected_k) * softmax_scale
        if mask_by_sim:
            attn_weight_cache = attn_weight_cache.masked_fill(
                sim_mask, memory_min_val)

        if attn_bias_ae is not None:  # add alibi bias to memories
            _s_q = max(0, attn_bias_ae.size(2) - s_q)
            _s_k = max(0, attn_bias_ae.size(3) - s_k_ae)
            attn_bias_ae = attn_bias_ae[:, :, _s_q:, _s_k:]
            if (attn_bias_ae.size(-1) != 1 and attn_bias_ae.size(-1) != s_k_ae
               ) or (attn_bias_ae.size(-2) != 1 and
                     attn_bias_ae.size(-2) != s_q):
                raise RuntimeError(
                    f'attn_bias (shape: {attn_bias_ae.shape}) is expected to broadcast to shape: {attn_weight_cache.shape}.'
                )
            attn_weight_cache = attn_weight_cache + attn_bias_ae

        # Memory scores/values go in front of the local ones.
        attn_weight = torch.cat([attn_weight_cache, attn_weight], dim=-1)
        v = torch.cat([selected_v, v], dim=-2)

    min_val = torch.finfo(q.dtype).min

    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn(
                'Propogating key_padding_mask to the attention module ' +\
                'and applying it within the attention module can cause ' +\
                'unneccessary computation/memory usage. Consider integrating ' +\
                'into attn_bias once and passing that to each attention ' +\
                'module instead.'
            )
        attn_weight = attn_weight.masked_fill(
            ~key_padding_mask.view((b, 1, 1, s_k)), min_val)

    def _create_active_externalism_mask(k, s_q, device):
        # True marks memory columns a query row must NOT attend to: row i
        # may only see its own block of k retrieved memories.
        mask = torch.zeros(s_q, s_q * k, device=device, dtype=torch.bool)
        for i in range(s_q):
            mask[i, i * k:(i + 1) * k] = 1
        return ~mask

    if is_causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]

        if long_range_past_key_value is not None:
            mask = _create_active_externalism_mask(k=topk,
                                                   s_q=s_q,
                                                   device=attn_weight.device)
            # Keep only the columns of the causal mask that correspond to
            # local keys, then prepend the per-query memory mask.
            s = s_q
            if had_kv:
                s += (past_key_value[0][0].size(-1) - s_q)
            causal_mask = torch.cat([mask, causal_mask[:, -s:]], dim=1)

        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
                                              min_val)

    attn_weight = torch.softmax(attn_weight, dim=-1)

    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight,
                                                  p=dropout_p,
                                                  training=training,
                                                  inplace=True)

    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')

    if needs_weights:
        return out, attn_weight, past_key_value, reshaped_idx
    return out, None, past_key_value, None
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
    """Validate that every tensor has an accepted dtype and lives on CUDA.

    Args:
        *tensors: Tensors to validate, checked in order.
        valid_dtypes: Accepted dtypes.

    Raises:
        TypeError: On the first tensor whose dtype is not in ``valid_dtypes``
            or which is not a CUDA tensor (the dtype is checked first).
    """
    for tensor in tensors:
        dtype_allowed = tensor.dtype in valid_dtypes
        if not dtype_allowed:
            raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
        if tensor.is_cuda:
            continue
        raise TypeError(f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
def flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Attention via the flash-attn unpadded kernel.

    Args:
        query: Tensor laid out as ``(b, s, h * d)``.
        key: Tensor laid out as ``(b, s, h * d)`` (``h == 1`` if multiquery).
        value: Tensor laid out as ``(b, s, h * d)`` (``h == 1`` if multiquery).
        n_heads: Number of query heads.
        past_key_value: Optional kv cache, concatenated along dim 1; an empty
            sequence starts a new cache.
        softmax_scale: Score scale forwarded to the kernel.
        attn_bias: Must be ``None``; additive bias is not supported here.
        key_padding_mask: Optional boolean mask over key positions; defaults
            to all ones.
        is_causal: Requested causal flag (see ``_reset_is_causal``).
        dropout_p: Dropout probability, only applied when ``training``.
        training: Whether dropout is active.
        needs_weights: Forwarded to the kernel as ``return_attn_probs``.
        multiquery: Keys/values have a single head that is expanded to
            ``n_heads``.

    Returns:
        ``(output, None, past_key_value)``.

    Raises:
        RuntimeError: If flash-attn cannot be imported.
        NotImplementedError: If ``attn_bias`` is given.
    """
    try:
        from flash_attn import bert_padding, flash_attn_interface  # type: ignore # yapf: disable # isort: skip
    except Exception as err:
        # Keep the original failure attached so the real import problem is
        # visible, and let KeyboardInterrupt/SystemExit propagate.
        raise RuntimeError('Please install flash-attn==1.0.3.post0') from err

    check_valid_inputs(query, key, value)

    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)

    if attn_bias is not None:
        raise NotImplementedError('attn_bias not implemented for flash attn.')

    batch_size, seqlen = query.shape[:2]

    if key_padding_mask is None:
        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
    # Queries correspond to the last query.size(1) key positions.
    query_padding_mask = key_padding_mask[:, -query.size(1):]

    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
        query, query_padding_mask)
    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

    kv_heads = 1 if multiquery else n_heads
    key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
        key, key_padding_mask)
    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_heads)

    value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_heads)

    if multiquery:
        # Expanding a tensor does not allocate new memory, but only creates a new
        # view on the existing tensor where a dimension of size one is expanded
        # to a larger size by setting the stride to 0.
        # - pytorch docs
        #
        # hopefully the kernels can utilize this and we're not just wasting BW here
        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
                                     key_unpad.size(-1))
        value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
                                         value_unpad.size(-1))

    dropout_p = dropout_p if training else 0.0

    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)

    # NOTE(review): attention probabilities are requested from the kernel when
    # needs_weights is set, yet this function always returns None for the
    # weights — confirm the kernel's return shape in that mode.
    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
        query_unpad,
        key_unpad,
        value_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=reset_is_causal,
        return_attn_probs=needs_weights)

    output = bert_padding.pad_input(
        rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
        seqlen)
    return output, None, past_key_value
def triton_flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Attention via the triton flash-attention kernel.

    Args:
        query: Tensor laid out as ``(b, s, h * d)``.
        key: Tensor laid out as ``(b, s, h * d)`` (``h == 1`` if multiquery).
        value: Tensor laid out as ``(b, s, h * d)`` (``h == 1`` if multiquery).
        n_heads: Number of query heads.
        past_key_value: Optional kv cache, concatenated along dim 1; an empty
            sequence starts a new cache.
        softmax_scale: Score scale forwarded to the kernel.
        attn_bias: Optional additive bias; trimmed to the trailing
            query/key positions.
        key_padding_mask: Optional boolean mask over keys; it is folded into
            ``attn_bias``.
        is_causal: Requested causal flag (see ``_reset_is_causal``).
        dropout_p: Must be falsy; dropout is not supported.
        training: Unused by this implementation.
        needs_weights: Must be falsy; weights cannot be returned.
        multiquery: Keys/values have a single head that is expanded to
            ``n_heads``.

    Returns:
        ``(output, None, past_key_value)``.

    Raises:
        RuntimeError: If no triton flash-attention implementation is
            importable.
        NotImplementedError: If dropout or attention weights are requested.
    """
    try:
        from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
    except Exception as err:
        # Broad on purpose (any import-time failure triggers the fallback),
        # but no longer swallows KeyboardInterrupt/SystemExit.
        _installed = False
        if version.parse(torch.__version__) < version.parse('2.0.0'):
            _installed = True
            # if torch1.13.1 revert to using triton flash attn from HazyResearch
            # with flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202
            try:
                from flash_attn.flash_attn_triton import flash_attn_func
            except Exception:
                _installed = False
        if not _installed:
            # installing triton-pre-mlir works for both torch1.13.1 and torch2.0+
            # default recommendation is to install this variant
            raise RuntimeError(
                'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU '
                'and `pip install .[gpu]` if installing from llm-foundry source or '
                '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` '
                'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). '
                'Note: (1) requires you have CMake and PyTorch already installed.'
            ) from err

    check_valid_inputs(query, key, value)

    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)

    if attn_bias is not None:
        # clamp to 0 necessary for torch 2.0 compile()
        _s_q = max(0, attn_bias.size(2) - query.size(1))
        _s_k = max(0, attn_bias.size(3) - key.size(1))
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]

    if dropout_p:
        raise NotImplementedError(
            'Dropout not implemented for attn_impl: triton.')

    if needs_weights:
        raise NotImplementedError(
            'attn_impl: triton cannot return attn weights.')

    if key_padding_mask is not None:
        warnings.warn(
            'Propagating key_padding_mask to the attention module ' +\
            'and applying it within the attention module can cause ' +\
            'unnecessary computation/memory usage. Consider integrating ' +\
            'into attn_bias once and passing that to each attention ' +\
            'module instead.'
        )
        b_size, s_k = key_padding_mask.shape[:2]

        if attn_bias is None:
            attn_bias = query.new_zeros(b_size, 1, 1, s_k)

        # Padded key positions get the most negative representable score.
        attn_bias = attn_bias.masked_fill(
            ~key_padding_mask.view((b_size, 1, 1, s_k)),
            torch.finfo(query.dtype).min)

    kv_heads = 1 if multiquery else n_heads
    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
    key = rearrange(key, 'b s (h d) -> b s h d', h=kv_heads)
    value = rearrange(value, 'b s (h d) -> b s h d', h=kv_heads)

    if multiquery:
        # Expanding a tensor does not allocate new memory, but only creates a new
        # view on the existing tensor where a dimension of size one is expanded
        # to a larger size by setting the stride to 0.
        # - pytorch docs
        #
        # hopefully the kernels can utilize this and we're not just wasting BW here
        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
        value = value.expand(*value.shape[:2], n_heads, value.size(-1))

    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal,
                                  softmax_scale)

    # Merge the head and head_dim axes back into one feature axis.
    output = attn_output.view(*attn_output.shape[:2], -1)
    return output, None, past_key_value
class MultiheadAttention(nn.Module):
    """Multi-head self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias. External memories (``long_range_past_key_value`` /
    ``faiss_indexes``) are only handled by the torch implementation.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = 'triton',
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        """Build the fused QKV projection, optional QK layernorms and output projection.

        Args:
            d_model: Model width.
            n_heads: Number of attention heads.
            attn_impl: One of ``'flash'``, ``'triton'`` or ``'torch'``.
            clip_qkv: If set, clamp the fused QKV activations to
                ``[-clip_qkv, clip_qkv]``.
            qk_ln: Apply layernorm to queries and keys.
            softmax_scale: Score scale; defaults to
                ``1 / sqrt(d_model / n_heads)``.
            attn_pdrop: Attention dropout probability.
            low_precision_layernorm: Use ``LPLayerNorm`` for the QK norms.
            verbose: Emit implementation advice warnings when truthy.
            device: Device for the created parameters.

        Raises:
            ValueError: If ``attn_impl`` is not a known implementation.
        """
        super().__init__()

        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln

        self.d_model = d_model
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.attn_dropout_p = attn_pdrop

        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
        # for param init fn; enables shape based init of fused layers
        fuse_splits = (d_model, 2 * d_model)
        self.Wqkv._fused = (0, fuse_splits)  # type: ignore

        if self.qk_ln:
            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = layernorm_class(self.d_model, device=device)
            self.k_ln = layernorm_class(self.d_model, device=device)

        if self.attn_impl == 'flash':
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == 'triton':
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\
                    'it uses more memory. When training larger models this can trigger ' +\
                    'alloc retries which hurts performance. If encountered, we recommend ' +\
                    'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
                )
        elif self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\
                    '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\
                    'we recommend using `attn_impl: triton`.'
                )
        else:
            raise ValueError(f'{attn_impl=} is an invalid setting.')

        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True  # type: ignore

    def forward(
        self,
        x,
        past_key_value=None,
        long_range_past_key_value=None,
        attn_bias=None,
        attn_bias_ae=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
        topk=None,
        faiss_indexes=None,
        n_layers=None,
        current_layer=None,
        mask_by_sim=None,
        sim_threshold=None
    ):
        """Project ``x`` to Q/K/V, attend, and apply the output projection.

        Args:
            x: Input activations; chunked into Q/K/V along dim 2 after the
                fused projection.
            past_key_value: Optional kv cache forwarded to the kernel.
            long_range_past_key_value: Optional external memories (torch
                implementation only).
            attn_bias: Optional additive attention bias.
            attn_bias_ae: Optional additive bias for memory scores (torch
                implementation only).
            attention_mask: Optional key padding mask.
            is_causal: Apply causal masking.
            needs_weights: Ask the kernel for attention weights.
            topk: Number of memories retrieved per query token.
            faiss_indexes: Optional FAISS index pair (torch implementation
                only).
            n_layers: Total number of layers (FAISS query tagging).
            current_layer: Index of this layer (FAISS query tagging).
            mask_by_sim: Mask retrieved memories by similarity.
            sim_threshold: Similarity threshold; ``None`` means ``0.0``.

        Returns:
            ``(output, attn_weights, past_key_value, reshaped_idx)``.
            ``reshaped_idx`` is always ``None`` for flash/triton.

        Raises:
            NotImplementedError: If memories are supplied with a flash or
                triton implementation.
        """
        qkv = self.Wqkv(x)

        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)

        query, key, value = qkv.chunk(3, dim=2)

        key_padding_mask = attention_mask

        if self.qk_ln:
            # Applying layernorm to qk
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)

        # Arguments understood by every attention implementation.
        common_kwargs = dict(
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )

        if self.attn_impl == 'torch':
            context, attn_weights, past_key_value, reshaped_idx = self.attn_fn(
                query,
                key,
                value,
                self.n_heads,
                long_range_past_key_value=long_range_past_key_value,
                attn_bias_ae=attn_bias_ae,
                topk=topk,
                faiss_indexes=faiss_indexes,
                n_layers=n_layers,
                current_layer=current_layer,
                mask_by_sim=mask_by_sim,
                # An explicit None would override the kernel's 0.0 default and
                # break its `val > sim_threshold` comparison.
                sim_threshold=0.0 if sim_threshold is None else sim_threshold,
                **common_kwargs,
            )
        else:
            # flash/triton kernels accept neither the memory arguments nor
            # return memory indices; passing them used to raise TypeError.
            if long_range_past_key_value is not None or faiss_indexes is not None:
                raise NotImplementedError(
                    f'External memories are only supported with attn_impl: torch (got {self.attn_impl!r}).'
                )
            context, attn_weights, past_key_value = self.attn_fn(
                query,
                key,
                value,
                self.n_heads,
                **common_kwargs,
            )
            reshaped_idx = None

        return self.out_proj(context), attn_weights, past_key_value, reshaped_idx
class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = 'triton',
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        """Build the fused QKV projection (one shared K/V head) and output projection.

        Args:
            d_model: Model width.
            n_heads: Number of query heads.
            attn_impl: One of ``'flash'``, ``'triton'`` or ``'torch'``.
            clip_qkv: If set, clamp the fused QKV activations to
                ``[-clip_qkv, clip_qkv]``.
            qk_ln: Apply layernorm to queries and keys.
            softmax_scale: Score scale; defaults to ``1 / sqrt(head_dim)``.
            attn_pdrop: Attention dropout probability.
            low_precision_layernorm: Use ``LPLayerNorm`` for the QK norms.
            verbose: Emit implementation advice warnings when truthy.
            device: Device for the created parameters.

        Raises:
            ValueError: If ``attn_impl`` is not a known implementation.
        """
        super().__init__()

        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.softmax_scale = softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = attn_pdrop

        # NOTE: if we ever want to make attn TensorParallel, I'm pretty sure we'll
        # want to split Wqkv into Wq and Wkv where Wq can be TensorParallel but
        # Wkv shouldn't be TensorParallel
        # - vchiley
        self.Wqkv = nn.Linear(
            d_model,
            d_model + 2 * self.head_dim,
            device=device,
        )
        # for param init fn; enables shape based init of fused layers
        fuse_splits = (d_model, d_model + self.head_dim)
        self.Wqkv._fused = (0, fuse_splits)  # type: ignore

        if self.qk_ln:
            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = layernorm_class(d_model, device=device)
            self.k_ln = layernorm_class(self.head_dim, device=device)

        if self.attn_impl == 'flash':
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == 'triton':
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\
                    'it uses more memory. When training larger models this can trigger ' +\
                    'alloc retries which hurts performance. If encountered, we recommend ' +\
                    'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
                )
        elif self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\
                    '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\
                    'we recommend using `attn_impl: triton`.'
                )
        else:
            raise ValueError(f'{attn_impl=} is an invalid setting.')

        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True  # type: ignore

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        """Project ``x`` to Q and a single shared K/V head, attend, and project out.

        Args:
            x: Input activations; split along dim 2 into
                ``[d_model, head_dim, head_dim]`` after the fused projection.
            past_key_value: Optional kv cache forwarded to the kernel.
            attn_bias: Optional additive attention bias.
            attention_mask: Optional key padding mask.
            is_causal: Apply causal masking.
            needs_weights: Ask the kernel for attention weights.

        Returns:
            ``(output, attn_weights, past_key_value)``.
        """
        qkv = self.Wqkv(x)

        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)

        query, key, value = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2)

        key_padding_mask = attention_mask

        if self.qk_ln:
            # Applying layernorm to qk
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)

        attn_outputs = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        # The torch kernel returns a fourth element (memory indices) that the
        # flash/triton kernels do not; only the first three are used here.
        context, attn_weights, past_key_value = attn_outputs[:3]

        return self.out_proj(context), attn_weights, past_key_value
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
                    use_sequence_id):
    """Return the shape of the attention bias tensor an implementation needs.

    Args:
        attn_impl: ``'flash'``, ``'torch'`` or ``'triton'``.
        n_heads: Number of attention heads.
        seq_len: Sequence length.
        alibi: Whether ALiBi is used.
        prefix_lm: Whether prefix-LM masking is used.
        causal: Whether attention is causal.
        use_sequence_id: Whether sequence-id masking is used.

    Returns:
        A 4-tuple shape, or ``None`` when no bias is needed (always the case
        for ``'flash'``).

    Raises:
        ValueError: If ``attn_impl`` is not a known implementation.
    """
    if attn_impl == 'flash':
        return None
    if attn_impl not in ('torch', 'triton'):
        raise ValueError(f'{attn_impl=} is an invalid setting.')

    # Prefix-LM and sequence-id masks differ per query row.
    per_query_mask = prefix_lm or use_sequence_id
    if alibi:
        query_rows = seq_len if (per_query_mask or not causal) else 1
        return (1, n_heads, query_rows, seq_len)
    if per_query_mask:
        return (1, 1, seq_len, seq_len)
    return None
def build_attn_bias(
    attn_impl,
    n_heads,
    seq_len,
    attn_bias=None,
    causal=False,
    alibi=False,
    alibi_bias_max=8,
    for_ae=False,
    topk=0,
    device=None,
    dtype=None
):
    """Build (or extend) the additive attention bias for an implementation.

    Args:
        attn_impl: ``'flash'``, ``'torch'`` or ``'triton'``.
        n_heads: Number of attention heads.
        seq_len: Sequence length.
        attn_bias: Optional existing bias that the ALiBi bias is added to.
            When ``None`` and ``alibi`` is set (the path used for memories),
            the ALiBi bias itself is returned.
        causal: Whether attention is causal; a full (non-causal) ALiBi bias
            is built otherwise.
        alibi: Whether to include the ALiBi bias.
        alibi_bias_max: Maximum ALiBi bias exponent.
        for_ae: Build the bias for retrieved memories.
        topk: Number of memories per query token (used with ``for_ae``).
        device: Device for the ALiBi bias.
        dtype: Dtype for the ALiBi bias.

    Returns:
        ``None`` for ``'flash'``; otherwise ``attn_bias`` (possibly ``None``)
        with the ALiBi bias added when requested. The sum is computed out of
        place, so the tensor passed in is not modified.

    Raises:
        ValueError: If ``attn_impl`` is not a known implementation.
    """
    if attn_impl == 'flash':
        return None
    if attn_impl not in ('torch', 'triton'):
        # Matches attn_bias_shape, instead of silently returning None.
        raise ValueError(f'{attn_impl=} is an invalid setting.')
    if not alibi:
        return attn_bias

    # device/dtype are honoured on both paths; they used to be dropped when
    # no existing bias was supplied.
    alibi_bias = build_alibi_bias(
        n_heads,
        seq_len,
        full=not causal,
        alibi_bias_max=alibi_bias_max,
        device=device,
        dtype=dtype,
        for_ae=for_ae,
        topk=topk,
    )
    if attn_bias is None:  # for memories
        return alibi_bias
    return attn_bias.add(alibi_bias)
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
    """Compute the per-head ALiBi slopes.

    Args:
        n_heads: Number of attention heads.
        alibi_bias_max: Largest exponent used for the slopes.
        device: Device on which the slopes are created.

    Returns:
        A float32 tensor of shape ``(1, n_heads, 1, 1)``.
    """
    # Slopes are generated for the next power of two >= n_heads.
    padded_heads = 2**math.ceil(math.log2(n_heads))
    exponents = torch.arange(1,
                             padded_heads + 1,
                             dtype=torch.float32,
                             device=device)
    exponents = exponents.mul(alibi_bias_max / padded_heads)
    slopes = 1. / torch.pow(2, exponents)

    if padded_heads != n_heads:
        # if n_heads is not a power of two,
        # Huggingface and FasterTransformer calculate slopes normally,
        # then return this strided concatenation of slopes
        odd_positions = slopes[1::2]
        even_positions = slopes[::2]
        slopes = torch.concat([odd_positions, even_positions])[:n_heads]

    return slopes.view(1, n_heads, 1, 1)
def build_alibi_bias(
    n_heads,
    seq_len,
    full=False,
    alibi_bias_max=8,
    device=None,
    dtype=None,
    for_ae=False,
    topk=0
):
    """Build the ALiBi attention bias.

    Args:
        n_heads: Number of attention heads.
        seq_len: Sequence length.
        full: Build a per-query-row bias instead of a single broadcast row.
        alibi_bias_max: Maximum exponent passed to ``gen_slopes``.
        device: Device on which the bias is created.
        dtype: Dtype the result is cast to.
        for_ae: Build the bias for retrieved memories: every one of the
            ``seq_len * topk`` memory slots gets the constant offset
            ``-seq_len``.
        topk: Number of memories per query token (used with ``for_ae``).

    Returns:
        The position offsets multiplied by the per-head slopes, cast to
        ``dtype``.
    """
    if for_ae:
        n_memory_slots = seq_len * topk
        offsets = torch.tensor(-seq_len, dtype=torch.int32, device=device)
        offsets = offsets.repeat(n_memory_slots).view(1, 1, 1, n_memory_slots)
    else:
        offsets = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device)
        offsets = offsets.view(1, 1, 1, seq_len)

    if full:
        # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
        # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
        query_offsets = torch.arange(1 - seq_len,
                                     1,
                                     dtype=torch.int32,
                                     device=device).view(1, 1, seq_len, 1)
        offsets = (offsets - query_offsets).abs().mul(-1)

    head_slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
    scaled = offsets * head_slopes
    return scaled.to(dtype=dtype)
# Maps an attention-type name to the attention module class implementing it.
ATTN_CLASS_REGISTRY = {
    'multihead_attention': MultiheadAttention,
    'multiquery_attention': MultiQueryAttention,
}
|