File size: 35,460 Bytes
c2d160f 0d4fdad 8cc3046 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f d93abf8 c2d160f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 |
# Adapted from https://github.com/mosaicml/llm-foundry
# Classes changed: MPTModel, MPTForCausalLM (renamed here to ExtendedMPTModel, ExtendedMPTForCausalLM)
# SPDX-License-Identifier: Apache-2.0
"""A simple, flexible implementation of a GPT model.
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.linalg import vector_norm
import faiss
from einops import rearrange
from composer.utils import dist
from omegaconf import DictConfig
from transformers import (PreTrainedModel, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY
from .configuration import ExtendedMPTConfig
from .attention import attn_bias_shape, build_attn_bias
from .blocks import MPTBlock
from .utils import instantiate_from_config
# Type alias accepting either tokenizer class from `transformers`.
Tokenizer = Union[
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
]
class MPTPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face base class for the extended MPT models.

    Ties the ``PreTrainedModel`` machinery to ``ExtendedMPTConfig``, names the
    base-model attribute prefix ``'model'``, and lists ``MPTBlock`` in
    ``_no_split_modules`` (consulted by Hugging Face device-map placement).
    """

    _no_split_modules = ['MPTBlock']
    base_model_prefix = 'model'
    config_class = ExtendedMPTConfig
class ExtendedMPTModel(MPTPreTrainedModel):
    """MPT decoder stack with support for active externalism.

    Adapted from llm-foundry's ``MPTModel``. Besides the standard MPT options,
    ``config.attn_config`` supplies ``mask_by_sim``, ``sim_threshold``,
    ``topk`` and ``use_active_externalism``; ``config.use_active_externalism_by_layer``
    is indexed per block to decide whether externally supplied key/value
    memories are handed to that block.
    """

    def __init__(self, config: ExtendedMPTConfig):
        config._validate_config()
        super().__init__(config)
        self.attn_impl = config.attn_config['attn_impl']
        self.prefix_lm = config.attn_config['prefix_lm']
        self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
        self.alibi = config.attn_config['alibi']
        self.alibi_bias_max = config.attn_config['alibi_bias_max']
        # Active-externalism settings, forwarded to every block in forward().
        self.mask_by_sim = config.attn_config['mask_by_sim']
        self.sim_threshold = config.attn_config['sim_threshold']
        self.topk = config.attn_config['topk']
        self.use_active_externalism = config.attn_config['use_active_externalism']
        self.use_active_externalism_by_layer = config.use_active_externalism_by_layer

        if config.init_device == 'mixed':
            # Local rank 0 materializes weights on CPU, other ranks on 'meta'.
            if dist.get_local_rank() == 0:
                config.init_device = 'cpu'
            else:
                config.init_device = 'meta'

        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
            norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
            raise NotImplementedError(
                f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).'
            )
        norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]

        # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
        # both report this helping with stabilizing training
        self.embedding_fraction = config.embedding_fraction

        self.wte = SharedEmbedding(config.vocab_size,
                                   config.d_model,
                                   device=config.init_device)
        if not self.alibi:
            # Learned absolute position embeddings are only needed without ALiBi.
            self.wpe = torch.nn.Embedding(config.max_seq_len,
                                          config.d_model,
                                          device=config.init_device)
        self.emb_drop = nn.Dropout(config.emb_pdrop)
        self.blocks = nn.ModuleList([
            MPTBlock(
                device=config.init_device,
                **config.to_dict(),
            ) for _ in range(config.n_layers)
        ])
        self.norm_f = norm_class(config.d_model, device=config.init_device)

        if config.init_device != 'meta':
            print(
                f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
            )
            self.apply(self.param_init_fn)

        self.is_causal = not self.prefix_lm

        # Attention bias is built lazily on the first call to _attn_bias().
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = attn_bias_shape(
            self.attn_impl,
            config.n_heads,
            config.max_seq_len,
            self.alibi,
            prefix_lm=self.prefix_lm,
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id,
        )
        # Separate bias for active externalism; rebuilt whenever it is used.
        self._attn_bias_ae_initialized = False
        self.attn_bias_ae = None

        if self.config.no_bias:
            for module in self.modules():
                if hasattr(module, 'bias') and isinstance(
                        module.bias, nn.Parameter):
                    if self.config.verbose:
                        warnings.warn(
                            f'Removing bias ({module.bias}) from {module}.')
                    module.register_parameter('bias', None)

        # Print verbose info
        if config.verbose and config.verbose > 2:
            print(self)
        if 'verbose' not in self.config.init_config:
            self.config.init_config['verbose'] = self.config.verbose
        if self.config.init_config['verbose'] > 1:
            init_fn_name = self.config.init_config['name']
            warnings.warn(f'Using {init_fn_name} initialization.')

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value: nn.Embedding):
        self.wte = value

    @torch.no_grad()
    def _attn_bias(
        self,
        device,
        dtype,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        seq_len: Optional[int] = None,
        use_active_externalism: Optional[bool] = None,
        topk=None,
    ):
        """Build the attention biases consumed by every block.

        The main bias is built once (for ``config.max_seq_len``) and cached on
        the module. When ``use_active_externalism`` is truthy, a separate bias
        for the external-memory attention is rebuilt on every call because it
        depends on ``seq_len``.

        Returns:
            ``(attn_bias, attn_bias_ae, attention_mask)``. For ``attn_impl ==
            'flash'`` the attention mask is returned unchanged so the attention
            module can apply it; otherwise it is folded into ``attn_bias`` and
            ``None`` is returned in its place.
        """
        if not self._attn_bias_initialized:
            if self.attn_bias_shape:
                self.attn_bias = torch.zeros(self.attn_bias_shape,
                                             device=device,
                                             dtype=dtype)
                self.attn_bias = build_attn_bias(
                    self.attn_impl,
                    self.config.n_heads,
                    self.config.max_seq_len,
                    device=device,
                    dtype=dtype,
                    attn_bias=self.attn_bias,
                    causal=self.is_causal,
                    alibi=self.alibi,
                    alibi_bias_max=self.alibi_bias_max,
                )
            self._attn_bias_initialized = True

        if use_active_externalism:
            # Rebuilt every time since seq_len changes between calls.
            self.attn_bias_ae = build_attn_bias(
                self.attn_impl,
                self.config.n_heads,
                seq_len,
                device=device,
                dtype=dtype,
                causal=self.is_causal,
                alibi=self.alibi,
                alibi_bias_max=self.alibi_bias_max,
                for_ae=use_active_externalism,
                topk=topk,
            )
            self._attn_bias_ae_initialized = True

        # flash does not support prefix_lm and will incorporate any
        # attention_mask inside the attention module. Return the same
        # 3-tuple shape as the other paths (forward() unpacks three values).
        if self.attn_impl == 'flash':
            return self.attn_bias, self.attn_bias_ae, attention_mask

        # .to(*args, **kwargs) is a no-op if the tensor is already on the
        # specified device and of the specified dtype.
        if self.attn_bias is not None:
            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
        attn_bias = self.attn_bias

        if self.attn_bias_ae is not None:
            self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
        attn_bias_ae = self.attn_bias_ae

        # If using torch or triton, we incorporate the prefix_mask (if appropriate)
        if self.prefix_lm:
            assert isinstance(attn_bias, torch.Tensor)  # pyright
            assert isinstance(prefix_mask, torch.Tensor)  # pyright
            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)

        # If using torch or triton, we incorporate sequence_id (if appropriate)
        if self.attn_uses_sequence_id and sequence_id is not None:
            assert isinstance(attn_bias, torch.Tensor)  # pyright
            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)

        # If using torch or triton, we incorporate attention_mask. This will
        # output None in place of attention_mask since it will not be further
        # needed in the attention modules.
        if attention_mask is not None:
            s_k = attention_mask.shape[-1]
            if attn_bias is None:
                attn_bias = torch.zeros((1, 1, 1, s_k),
                                        device=device,
                                        dtype=dtype)
            else:
                # clamp to 0 necessary for torch 2.0 compile()
                _s_k = max(0, attn_bias.size(-1) - s_k)
                attn_bias = attn_bias[:, :, :, _s_k:]
            if prefix_mask is not None and (attention_mask.shape !=
                                            prefix_mask.shape):
                raise ValueError(
                    f'attention_mask shape={attention_mask.shape} ' +
                    f'and prefix_mask shape={prefix_mask.shape} are not equal.')
            min_val = torch.finfo(attn_bias.dtype).min
            attn_bias = attn_bias.masked_fill(
                ~attention_mask.view(-1, 1, 1, s_k), min_val)

        return attn_bias, attn_bias_ae, None

    def _apply_prefix_mask(self, attn_bias: torch.Tensor,
                           prefix_mask: torch.Tensor):
        """Mask positions that are neither causally visible nor in the prefix.

        ``attn_bias`` must have its last two dims equal to
        ``config.max_seq_len``; it is cropped to the prefix_mask length.
        """
        s_k, s_q = attn_bias.shape[-2:]
        if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
            raise ValueError(
                'attn_bias does not match the expected shape. ' +
                f'The last two dimensions should both be {self.config.max_seq_len} '
                + f'but are {s_k} and {s_q}.')
        seq_len = prefix_mask.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
            )

        # select seq_len subset of attn mask
        attn_bias = attn_bias[..., :seq_len, :seq_len]

        # Mix the causal mask and the bidirectional mask to get the full
        # allowable attention (i.e. full = not accounting for padding yet)
        causal = torch.tril(
            torch.ones((seq_len, seq_len),
                       dtype=torch.bool,
                       device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
        prefix = prefix_mask.view(-1, 1, 1, seq_len)
        cannot_attend = ~torch.logical_or(causal, prefix.bool())

        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def _apply_sequence_id(self, attn_bias: torch.Tensor,
                           sequence_id: torch.LongTensor):
        """Block attention between tokens whose ``sequence_id`` values differ."""
        seq_len = sequence_id.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
            )

        # select seq_len subset of attn mask
        attn_bias = attn_bias[..., :seq_len, :seq_len]

        # Restrict attention to tokens that share the same value in sequence_id
        cannot_attend = torch.logical_not(
            torch.eq(
                sequence_id.view(-1, seq_len, 1),
                sequence_id.view(-1, 1, seq_len),
            )).unsqueeze(1)
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_active_externalism: Optional[bool] = None,
        long_range_past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        faiss_indexes: Optional[Tuple] = None,
        topk: Optional[int] = None,
    ):
        """Run the decoder stack.

        Args:
            input_ids: token ids indexed as ``[batch, seq]``.
            past_key_values: per-layer KV cache from a previous call.
            long_range_past_key_values: per-layer external ``(k, v)`` memories;
                handed to a block only when ``use_active_externalism`` is True
                and ``use_active_externalism_by_layer`` is set for that block.
            faiss_indexes: faiss index pair used as external memory; mutually
                exclusive with ``long_range_past_key_values``.
            topk: retrieval count; defaults to ``attn_config['topk']``.
            use_active_externalism: defaults to
                ``attn_config['use_active_externalism']``.

        Returns:
            ``BaseModelOutputWithPast`` whose ``attentions`` field is the pair
            ``(per-layer attention weights, per-layer reshaped_idx)`` (both
            ``None`` unless ``output_attentions``).

        Raises:
            NotImplementedError / ValueError for unsupported argument
            combinations (see checks below).
        """
        return_dict = (return_dict
                       if return_dict is not None else self.config.return_dict)
        use_cache = (use_cache
                     if use_cache is not None else self.config.use_cache)
        use_active_externalism = (use_active_externalism
                                  if use_active_externalism is not None else
                                  self.use_active_externalism)
        topk = topk if topk is not None else self.topk

        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        if prefix_mask is not None:
            prefix_mask = prefix_mask.bool()

        # These args are passed in by keyword in huggingface's generate function
        # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
        # but have not yet been fully implemented in MPTModel
        if not return_dict:
            raise NotImplementedError(
                'return_dict False is not implemented yet for MPT')
        if output_attentions:
            if self.attn_impl != 'torch':
                raise NotImplementedError(
                    'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
                )

        if (attention_mask is not None and
                attention_mask[:, 0].sum() != attention_mask.shape[0] and
                self.training):
            raise NotImplementedError(
                'MPT does not support training with left padding.')

        if self.prefix_lm and prefix_mask is None:
            raise ValueError(
                'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
            )

        # Raise a not implemented error if input_embeds is not None (this is an
        # arg in huggingface transformers and we need to support it for PEFT)
        if inputs_embeds is not None:
            raise NotImplementedError(
                'inputs_embeds is not implemented for MPT.')

        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError(
                    'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True '
                    + 'and the model is in train mode.')
            elif (self.attn_uses_sequence_id is False) and (sequence_id
                                                            is not None):
                warnings.warn(
                    'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. '
                    +
                    'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
                )

        S = input_ids.size(1)

        assert (
            S <= self.config.max_seq_len
        ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

        tok_emb = self.wte(input_ids)  # type: ignore
        if self.alibi:
            x = tok_emb
        else:
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(
                        f'past_key_values must provide a past_key_value for each attention '
                        +
                        f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
                    )
                # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
                # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
                # Here we shift position embedding using the `seq` dim of the past key
                past_position = past_key_values[0][0].size(1)
                if self.attn_impl == 'torch':
                    past_position = past_key_values[0][0].size(3)

            if S + past_position > self.config.max_seq_len:
                raise ValueError(
                    f'Cannot forward input with past sequence length {past_position} and current sequence length '
                    f'{S}, this model only supports total sequence length <= {self.config.max_seq_len}.'
                )
            pos = torch.arange(
                past_position,
                S + past_position,
                dtype=torch.long,
                device=input_ids.device,
            ).unsqueeze(0)
            if attention_mask is not None:
                # adjust the position indices to account for padding tokens
                pos = torch.clamp(
                    pos - torch.cumsum((~attention_mask).to(torch.int32),
                                       dim=1)[:, past_position:],
                    min=0,
                )

            pos_emb = self.wpe(pos)  # type: ignore
            x = tok_emb + pos_emb

        if self.embedding_fraction == 1:
            x = self.emb_drop(x)  # type: ignore
        else:
            # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
            x_shrunk = (x * self.embedding_fraction) + (
                x.detach() * (1 - self.embedding_fraction))
            assert isinstance(self.emb_drop, nn.Module)  # pyright
            x = self.emb_drop(x_shrunk)

        # Total key length (cached + current) used to size the
        # active-externalism bias. The cached sequence axis follows the same
        # layout as above: dim 3 for attn_impl='torch', dim 1 otherwise.
        seq_len = S
        if past_key_values is not None:
            if self.attn_impl == 'torch':
                seq_len += past_key_values[0][0].size(3)
            else:
                seq_len += past_key_values[0][0].size(1)

        attn_bias, attn_bias_ae, attention_mask = self._attn_bias(
            device=x.device,
            dtype=torch.float32,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            seq_len=seq_len,
            use_active_externalism=use_active_externalism,
            topk=topk,
        )

        # initialize the past key values cache if it should be used
        if use_cache and past_key_values is None:
            past_key_values = [() for _ in range(self.config.n_layers)
                               ]  # type: ignore

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_idx = () if output_attentions else None
        for b_idx, block in enumerate(self.blocks):  # type: ignore
            if output_hidden_states:
                assert all_hidden_states is not None  # pyright
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = (past_key_values[b_idx]
                              if past_key_values is not None else None)
            # External memories only reach layers enabled for active externalism.
            long_range_past_key_value = (
                long_range_past_key_values[b_idx]
                if (long_range_past_key_values is not None and
                    self.use_active_externalism_by_layer[b_idx] and
                    use_active_externalism is True) else None)
            if long_range_past_key_value is not None and faiss_indexes is not None:
                raise NotImplementedError(
                    'Using faiss and passing key value pairs manually are mutually exclusive right now.')

            x, attn_weights, past_key_value, reshaped_idx = block(
                x,
                past_key_value=past_key_value,
                long_range_past_key_value=long_range_past_key_value,
                attn_bias=attn_bias,
                attention_mask=attention_mask,
                attn_bias_ae=attn_bias_ae,
                is_causal=self.is_causal,
                topk=topk,
                needs_weights=output_attentions,
                faiss_indexes=faiss_indexes,
                n_layers=self.config.n_layers,
                current_layer=b_idx,
                mask_by_sim=self.mask_by_sim,
                sim_threshold=self.sim_threshold,
            )
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value

            if output_attentions:
                assert all_self_attns is not None  # pyright
                all_self_attns = all_self_attns + (attn_weights,)
                assert all_idx is not None
                all_idx = all_idx + (reshaped_idx,)

        x = self.norm_f(x)  # type: ignore

        # add hidden states from the last decoder layer
        if output_hidden_states:
            assert all_hidden_states is not None  # pyright
            all_hidden_states = all_hidden_states + (x,)

        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            # reshaped_idx is returned alongside the weights for active externalism
            attentions=(all_self_attns, all_idx),
        )

    # Param Initialization, needed for device='meta' fast initialization
    def param_init_fn(self, module):
        init_fn_name = self.config.init_config['name']
        MODEL_INIT_REGISTRY[init_fn_name](
            module=module,
            n_layers=self.config.n_layers,
            d_model=self.config.d_model,
            **self.config.init_config,
        )

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, MPTBlock)

    # Activation Checkpointing
    def activation_checkpointing_fn(self, module):
        return isinstance(module, MPTBlock)
class ExtendedMPTForCausalLM(MPTPreTrainedModel):
    """Causal-LM head (tied embeddings) on top of ``ExtendedMPTModel``.

    Optionally holds an external memory. Raw memory token ids passed as
    ``external_memories`` are encoded into a key/value cache (a list of
    per-layer ``(k, v)`` pairs, or a pair of faiss indexes, depending on
    ``attn_config['memory_type']``) on the first ``forward`` call.
    """

    def __init__(self, config: ExtendedMPTConfig, external_memories=None):
        if isinstance(config, DictConfig):
            config = instantiate_from_config(config)
        super().__init__(config)

        if not config.tie_word_embeddings:
            raise ValueError(
                'MPTForCausalLM only supports tied word embeddings')

        print(f'Instantiating an MPTForCausalLM model from {__file__}')

        self.transformer: ExtendedMPTModel = ExtendedMPTModel(config)
        self.use_active_externalism = config.attn_config['use_active_externalism']
        self.memory_type = config.attn_config['memory_type']
        # Raw memory token ids; encoded lazily into `self.memories` in forward().
        self._memories = None
        self.memory_device = config.memory_device

        for child in self.transformer.children():
            if isinstance(child, torch.nn.ModuleList):
                continue
            if isinstance(child, torch.nn.Module):
                child._fsdp_wrap = True

        # enables scaling output logits; similar to a softmax "temperature"
        # PaLM paper uses scale 1/sqrt(config.d_model)
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == 'inv_sqrt_d_model':
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(
                        f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
                    )
            self.logit_scale = logit_scale

        if external_memories is not None:
            self._memories = external_memories
        # Encoded memories (list of per-layer (k, v) or faiss index pair).
        # Always defined so forward() can read it unconditionally.
        self.memories = None

    def set_memories(self, memories):
        self.memories = memories

    def empty_memories(self):
        self.memories = None

    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, value):
        self.transformer.wte = value

    def get_output_embeddings(self):
        return self.transformer.wte

    def set_output_embeddings(self, new_embeddings):
        self.transformer.wte = new_embeddings

    def set_decoder(self, decoder):
        self.transformer = decoder

    def get_decoder(self):
        return self.transformer

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_active_externalism: Optional[bool] = None,
        topk: Optional[int] = None,
    ):
        """Compute next-token logits (and loss when ``labels`` is given).

        On the first call with ``external_memories`` configured, the memories
        are encoded via :meth:`generate_cache`. A list in ``self.memories`` is
        passed to the transformer as ``long_range_past_key_values``; any other
        value (e.g. the faiss index pair) is passed as ``faiss_indexes``.
        ``topk=None`` lets the transformer fall back to its configured value.

        Returns:
            ``CausalLMOutputWithPast``; ``loss`` is the cross-entropy of each
            position against the next label (last position ignored).
        """
        if self._memories is not None and self.memories is None:
            # Encode the raw memories once, on first call.
            self.memories = self.generate_cache(self._memories,
                                                cache_type=self.memory_type)

        return_dict = (return_dict
                       if return_dict is not None else self.config.return_dict)
        use_cache = (use_cache
                     if use_cache is not None else self.config.use_cache)
        use_active_externalism = (use_active_externalism
                                  if use_active_externalism is not None else
                                  self.use_active_externalism)

        # if input_embeds is not none, raise a not implemented error
        if inputs_embeds is not None:
            raise NotImplementedError(
                'inputs_embeds has to be None (for hf/peft support).')

        memories = getattr(self, 'memories', None)
        if isinstance(memories, list):
            long_range_past_key_values, faiss_indexes = memories, None
        else:
            long_range_past_key_values, faiss_indexes = None, memories

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            long_range_past_key_values=long_range_past_key_values,
            faiss_indexes=faiss_indexes,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            use_active_externalism=use_active_externalism,
            topk=topk,
        )

        # move outputs to same device as weights for token embedding
        # needed to support HF `device_map`
        logits = self.transformer.wte(
            outputs.last_hidden_state.to(self.transformer.wte.weight.device),
            True,
        )

        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(
                    f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
                )
            logits *= self.logit_scale

        loss = None
        if labels is not None:
            # Shift labels left by one; the final column has no next token and
            # is set to -100 (cross_entropy's default ignore_index).
            _labels = torch.roll(labels, shifts=-1)
            _labels[:, -1] = -100
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                _labels.to(logits.device).view(-1),
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # Param Initialization, needed for device='meta' fast initialization
    def param_init_fn(self, module):
        init_fn_name = self.config.init_config['name']
        MODEL_INIT_REGISTRY[init_fn_name](
            module=module,
            n_layers=self.config.n_layers,
            d_model=self.config.d_model,
            **self.config.init_config,
        )

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, MPTBlock)

    # Activation Checkpointing
    def activation_checkpointing_fn(self, module):
        return isinstance(module, MPTBlock)

    def generate_cache(self,
                       input_ids: torch.LongTensor,
                       stride: int = 512,
                       max_len: int = 2048,
                       cache_type: str = 'manual'):
        """Encode ``input_ids`` into an external key/value memory.

        The sequence is processed in windows of up to ``max_len`` tokens whose
        start advances by ``stride``. From each window only the trailing
        positions not covered by the previous window are kept and appended to
        the memory via :meth:`cache`.

        Args:
            input_ids: token ids indexed as ``[batch, seq]``.
            stride: step between window starts.
            max_len: maximum window length fed to the transformer.
            cache_type: ``'manual'`` or ``'faiss'``.

        Returns:
            The per-layer ``(k, v)`` list for ``'manual'``, the
            ``(kn_index, kv_index)`` pair for ``'faiss'``, or ``None`` when
            ``input_ids`` contains no tokens.

        Raises:
            NotImplementedError: for an unsupported ``cache_type``.
        """
        if cache_type not in ['manual', 'faiss']:
            raise NotImplementedError(f"Cache type {cache_type} not implemented.")

        total_len = input_ids.size(-1)
        prev_end_loc = 0
        long_range_past_key_values = None
        faiss_indexes = None
        for b_idx in range(0, total_len, stride):  # generate kv-pairs using stride
            end_loc = min(b_idx + max_len, total_len)
            trg_len = end_loc - prev_end_loc  # positions new to this window
            subseq = input_ids[:, b_idx:end_loc].to(self.device)
            with torch.no_grad():
                outputs = self.transformer(subseq,
                                           use_cache=True,
                                           use_active_externalism=False)
            # Keys are (b, h, d, s) and values (b, h, s, d): keep the last trg_len positions.
            to_cache = [(
                kv[0][:, :, :, -trg_len:],
                kv[1][:, :, -trg_len:])
                for kv in outputs.past_key_values
            ]
            long_range_past_key_values, faiss_indexes = self.cache(
                to_cache,
                cache_type,
                long_range_past_key_values=long_range_past_key_values,
                faiss_indexes=faiss_indexes)

            prev_end_loc = end_loc
            if end_loc == total_len:
                break

        if long_range_past_key_values is not None:
            return long_range_past_key_values
        return faiss_indexes

    def cache(self,
              to_cache: List,
              cache_type: str = 'manual',
              long_range_past_key_values: List = None,
              faiss_indexes: faiss.IndexFlatIP = None,
              max_length_cache=100000,
              verbose=False):
        """Append per-layer key/value pairs to a manual cache or faiss indexes.

        Args:
            to_cache: one ``(k, v)`` pair per layer; keys laid out as
                ``(b, h, d, s)`` and values as ``(b, h, s, d)``.
            cache_type: ``'faiss'`` adds to ``faiss_indexes`` (created if
                ``None``); anything else appends to ``long_range_past_key_values``
                on ``self.memory_device``.
            long_range_past_key_values: existing manual cache, or ``None``.
            faiss_indexes: existing ``(kn_index, kv_index)`` pair, or ``None``.
            max_length_cache: the manual cache keeps only this many most
                recent positions.
            verbose: print the resulting cache size.

        Returns:
            ``(long_range_past_key_values, faiss_indexes)`` — the manual list
            (``None`` for faiss) and the index pair (``None`` for manual).

        Raises:
            NotImplementedError: if both a manual cache and faiss indexes are given.
        """
        if long_range_past_key_values is not None and faiss_indexes is not None:
            raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")

        if cache_type == 'faiss':
            n_heads = self.config.n_heads
            # One scaled one-hot tag per (layer, head) appended to every
            # normalized key — presumably so inner-product search favours keys
            # from the querying layer/head. Float so the concat stays float32.
            one_hot_encodings = F.one_hot(
                torch.arange(0, n_heads * self.config.n_layers)).to(torch.float32) * 10
            if faiss_indexes is None:
                faiss_indexes = (
                    faiss.IndexFlatIP(to_cache[0][0].size(-2) + one_hot_encodings.size(-1)),
                    faiss.IndexFlatIP(to_cache[0][1].size(-1) * 2),
                )
            kn_index, kv_index = faiss_indexes
            for b_idx, (k, v) in enumerate(to_cache):
                # faiss consumes float32 NumPy arrays; bf16 tensors cannot be
                # converted to NumPy at all, so upcast before leaving torch.
                k32 = k.float()
                k_n = (k32 / vector_norm(k32, ord=2, dim=-2, keepdim=True)).to('cpu')
                layer_tags = one_hot_encodings[n_heads * b_idx:n_heads * (b_idx + 1)]
                k_n = torch.concat([
                    rearrange(k_n, 'b h d s -> b (h s) d', h=n_heads),
                    layer_tags.unsqueeze(0).repeat_interleave(repeats=k.size(-1), dim=-2),
                ], dim=-1)
                # squeeze(0) drops only the batch dim (squeeze() could also
                # collapse a length-1 row axis).
                kn_index.add(k_n.squeeze(0).numpy())

                k = rearrange(k, 'b h d s -> b (h s) d', h=n_heads)
                v = rearrange(v, 'b h s d -> b (h s) d', h=n_heads)
                kv_rows = torch.concat([v.squeeze(0), k.squeeze(0)], dim=1)
                kv_index.add(kv_rows.to(device='cpu', dtype=torch.float32).numpy())
        else:
            if long_range_past_key_values is None:
                long_range_past_key_values = [
                    (k.to(self.memory_device), v.to(self.memory_device))
                    for k, v in to_cache
                ]
            else:
                # Keys grow along dim 3 (seq), values along dim 2 (seq).
                long_range_past_key_values = [
                    (
                        torch.concat([kv[0], to_cache[ind][0].to(self.memory_device)], dim=3),
                        torch.concat([kv[1], to_cache[ind][1].to(self.memory_device)], dim=2),
                    )
                    for ind, kv in enumerate(long_range_past_key_values)
                ]

        # set a limit on manual memory length
        if long_range_past_key_values is not None:
            if long_range_past_key_values[0][0].size(-1) > max_length_cache:
                long_range_past_key_values = [
                    (
                        kv[0][:, :, :, -max_length_cache:],
                        kv[1][:, :, -max_length_cache:],
                    )
                    for kv in long_range_past_key_values
                ]

        if verbose:
            if cache_type == 'faiss':
                print(f"{kn_index.ntotal} keys in faiss index")
            else:
                print(f"{long_range_past_key_values[0][0].size(-1)} cached kvs")

        return long_range_past_key_values, ((kn_index, kv_index) if cache_type == 'faiss' else None)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build ``forward`` kwargs for one Hugging Face ``generate`` step.

        Requires ``kwargs['attention_mask']``. With a KV cache only the last
        token is fed. ``use_active_externalism`` and ``topk`` are forwarded from
        ``kwargs`` (``None`` when absent).
        """
        if inputs_embeds is not None:
            raise NotImplementedError(
                'inputs_embeds is not implemented for MPT yet')

        attention_mask = kwargs['attention_mask'].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError(
                'MPT does not support generation with right padding.')

        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None

        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        if self.transformer.prefix_lm:
            # Leverage a convenience of sequential generation!
            prefix_mask = torch.ones_like(attention_mask)
            # This requires that we're using the cache
            if kwargs.get('use_cache') == False:
                raise NotImplementedError(
                    'MPT with prefix_lm=True does not support use_cache=False.')
        else:
            prefix_mask = None

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'prefix_mask': prefix_mask,
            'sequence_id': sequence_id,
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache', True),
            # extra kwargs for active externalism
            'use_active_externalism': kwargs.get('use_active_externalism'),
            'topk': kwargs.get('topk', None),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        reordered_past = []
        for layer_past in past_key_values:
            reordered_past += [
                tuple(
                    past_state.index_select(0, beam_idx)
                    for past_state in layer_past)
            ]
        return reordered_past