Files changed (4) hide show
  1. attention.py +5 -6
  2. blocks.py +1 -1
  3. configuration.py +5 -0
  4. modeling_mpt.py +15 -19
attention.py CHANGED
@@ -95,10 +95,10 @@ def scaled_multihead_dot_product_attention(
95
  )
96
  attn_weight = attn_weight + attn_bias
97
 
98
- if needs_weights:
99
  reshaped_idx = None
100
  if long_range_past_key_value is not None or faiss_indexes is not None:
101
- if long_range_past_key_value is not None: #manual external memories
102
 
103
  k_cache, v_cache = long_range_past_key_value
104
  s_cache = k_cache.size(-1)
@@ -134,15 +134,14 @@ def scaled_multihead_dot_product_attention(
134
 
135
  selected_k=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,:d], '(h s) d -> 1 h d s', h=32).to(q.device)
136
  selected_v=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,d:], '(h s) d -> 1 h s d', h=32).to(q.device)
137
-
138
  s_k_ae = selected_k.size(-1)
139
  s_k += s_k_ae
140
  attn_weight_cache = q.matmul(selected_k) * softmax_scale
141
  if mask_by_sim:
142
  attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, min_val)
143
 
144
- if attn_bias_ae is not None:
145
- # clamp to 0 necessary for torch 2.0 compile()
146
  _s_q = max(0, attn_bias_ae.size(2) - s_q)
147
  _s_k = max(0, attn_bias_ae.size(3) - s_k_ae)
148
  attn_bias_ae = attn_bias_ae[:, :, _s_q:, _s_k:]
@@ -710,7 +709,7 @@ def build_attn_bias(
710
  for_ae=for_ae,
711
  topk=topk
712
  ))
713
- else:
714
  attn_bias = build_alibi_bias(
715
  n_heads,
716
  seq_len,
 
95
  )
96
  attn_weight = attn_weight + attn_bias
97
 
98
+ if needs_weights: #will return memory indices w/attention weights
99
  reshaped_idx = None
100
  if long_range_past_key_value is not None or faiss_indexes is not None:
101
+ if long_range_past_key_value is not None: #manual memories
102
 
103
  k_cache, v_cache = long_range_past_key_value
104
  s_cache = k_cache.size(-1)
 
134
 
135
  selected_k=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,:d], '(h s) d -> 1 h d s', h=32).to(q.device)
136
  selected_v=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,d:], '(h s) d -> 1 h s d', h=32).to(q.device)
137
+
138
  s_k_ae = selected_k.size(-1)
139
  s_k += s_k_ae
140
  attn_weight_cache = q.matmul(selected_k) * softmax_scale
141
  if mask_by_sim:
142
  attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, min_val)
143
 
144
+ if attn_bias_ae is not None: #add alibi bias to memories
 
145
  _s_q = max(0, attn_bias_ae.size(2) - s_q)
146
  _s_k = max(0, attn_bias_ae.size(3) - s_k_ae)
147
  attn_bias_ae = attn_bias_ae[:, :, _s_q:, _s_k:]
 
709
  for_ae=for_ae,
710
  topk=topk
711
  ))
712
+ else: #for memories
713
  attn_bias = build_alibi_bias(
714
  n_heads,
715
  seq_len,
blocks.py CHANGED
@@ -7,7 +7,7 @@
7
  from typing import Dict, Optional, Tuple
8
  import torch
9
  import torch.nn as nn
10
- from .attention import ATTN_CLASS_REGISTRY
11
  from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
12
 
13
  class MPTMLP(nn.Module):
 
7
  from typing import Dict, Optional, Tuple
8
  import torch
9
  import torch.nn as nn
10
+ from extended_mind_transformers.mpt.attention import ATTN_CLASS_REGISTRY
11
  from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
12
 
13
  class MPTMLP(nn.Module):
configuration.py CHANGED
@@ -165,6 +165,11 @@ class ExtendedMPTConfig(PretrainedConfig):
165
  init_config_defaults,
166
  )
167
 
 
 
 
 
 
168
  if self.d_model % self.n_heads != 0:
169
  raise ValueError('d_model must be divisible by n_heads')
170
  if any(
 
165
  init_config_defaults,
166
  )
167
 
168
+ if self.attn_config['memory_type']=='faiss' and self.attn_config['mask_by_sim'] is True:
169
+ raise ValueError(
170
+ 'mask_by_sim is not supported for faiss memory type.'
171
+ )
172
+
173
  if self.d_model % self.n_heads != 0:
174
  raise ValueError('d_model must be divisible by n_heads')
175
  if any(
modeling_mpt.py CHANGED
@@ -27,10 +27,10 @@ from llmfoundry.models.layers.custom_embedding import SharedEmbedding
27
  from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
28
  from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY
29
 
30
- from .configuration import ExtendedMPTConfig
31
- from .attention import attn_bias_shape, build_attn_bias
32
- from .blocks import MPTBlock
33
- from .utils import instantiate_from_config
34
 
35
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
36
 
@@ -111,7 +111,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
111
  causal=self.is_causal,
112
  use_sequence_id=self.attn_uses_sequence_id,
113
  )
114
- self._attn_bias_ae_initialized = False
115
  self.attn_bias_ae = None
116
 
117
  if self.config.no_bias:
@@ -168,7 +168,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
168
  )
169
  self._attn_bias_initialized = True
170
 
171
- if use_active_externalism:
172
  self.attn_bias_ae = build_attn_bias(
173
  self.attn_impl,
174
  self.config.n_heads,
@@ -196,7 +196,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
196
 
197
  attn_bias = self.attn_bias
198
 
199
- if self.attn_bias_ae is not None:
200
  self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
201
  attn_bias_ae = self.attn_bias_ae
202
 
@@ -417,9 +417,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
417
  assert isinstance(self.emb_drop, nn.Module) # pyright
418
  x = self.emb_drop(x_shrunk)
419
 
420
- # self._attn_bias_initialized = False #right now this needs to run each step
421
-
422
- seq_len = S
423
  if past_key_values is not None:
424
  past_position = past_key_values[0][0].size(-1)
425
  seq_len += past_position
@@ -493,7 +491,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
493
  last_hidden_state=x,
494
  past_key_values=past_key_values,
495
  hidden_states=all_hidden_states,
496
- attentions=(all_self_attns, all_idx),
497
  )
498
 
499
  # Param Initialization, needed for device='meta' fast initialization
@@ -598,7 +596,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
598
  use_active_externalism: Optional[bool]=None,
599
  topk:int=None
600
  ):
601
- if self._memories is not None and self.memories is None:
602
  self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)
603
 
604
  return_dict = (return_dict
@@ -702,9 +700,8 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
702
  prev_end_loc=0
703
  long_range_past_key_values = None
704
  faiss_indexes= None
705
- for b_idx in range(0, input_ids.size(-1), stride):
706
  end_loc = min(b_idx + max_len, input_ids.size(-1))
707
-
708
  trg_len = end_loc - prev_end_loc
709
  subseq = input_ids[:, b_idx:end_loc].to(self.device)
710
  with torch.no_grad():
@@ -734,7 +731,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
734
  if long_range_past_key_values is not None and faiss_indexes is not None:
735
  raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")
736
 
737
- if cache_type=='faiss':
738
  one_hot_encodings = F.one_hot(torch.arange(0, self.config.n_heads*self.config.n_layers))*10
739
  if faiss_indexes is None:
740
  faiss_indexes = (faiss.IndexFlatIP(to_cache[0][0].size(-2)+one_hot_encodings.size(-1)), faiss.IndexFlatIP(to_cache[0][1].size(-1)*2))
@@ -747,7 +744,6 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
747
  k= rearrange(k, 'b h d s -> b (h s) d', h=self.config.n_heads)
748
  v= rearrange(v, 'b h s d -> b (h s) d', h=self.config.n_heads)
749
  kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
750
-
751
  else:
752
  if long_range_past_key_values is None:
753
  long_range_past_key_values = [(k.to(self.memory_device),v.to(self.memory_device)) for k,v in to_cache]
@@ -759,8 +755,8 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
759
  )
760
  for ind, kv in enumerate(long_range_past_key_values)
761
  ]
762
- if long_range_past_key_values is not None:
763
- if long_range_past_key_values[0][0].size(-1) > max_length_cache: #set a limit on manual memory length
764
  long_range_past_key_values = [
765
  (
766
  kv[0][:, :, :, -max_length_cache:],
@@ -816,7 +812,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
816
  'sequence_id': sequence_id,
817
  'past_key_values': past_key_values,
818
  'use_cache': kwargs.get('use_cache', True),
819
- 'use_active_externalism': kwargs.get('use_active_externalism'),
820
  'topk': kwargs.get('topk', None),
821
  }
822
 
 
27
  from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
28
  from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY
29
 
30
+ from extended_mind_transformers.mpt.configuration import ExtendedMPTConfig
31
+ from extended_mind_transformers.mpt.attention import attn_bias_shape, build_attn_bias
32
+ from extended_mind_transformers.mpt.blocks import MPTBlock
33
+ from extended_mind_transformers.utils import instantiate_from_config
34
 
35
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
36
 
 
111
  causal=self.is_causal,
112
  use_sequence_id=self.attn_uses_sequence_id,
113
  )
114
+ self._attn_bias_ae_initialized = False #for active externalism
115
  self.attn_bias_ae = None
116
 
117
  if self.config.no_bias:
 
168
  )
169
  self._attn_bias_initialized = True
170
 
171
+ if use_active_externalism: #for active externalism, init every time since seq_len changes
172
  self.attn_bias_ae = build_attn_bias(
173
  self.attn_impl,
174
  self.config.n_heads,
 
196
 
197
  attn_bias = self.attn_bias
198
 
199
+ if self.attn_bias_ae is not None: #for active externalism
200
  self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
201
  attn_bias_ae = self.attn_bias_ae
202
 
 
417
  assert isinstance(self.emb_drop, nn.Module) # pyright
418
  x = self.emb_drop(x_shrunk)
419
 
420
+ seq_len = S #for active externalism
 
 
421
  if past_key_values is not None:
422
  past_position = past_key_values[0][0].size(-1)
423
  seq_len += past_position
 
491
  last_hidden_state=x,
492
  past_key_values=past_key_values,
493
  hidden_states=all_hidden_states,
494
+ attentions=(all_self_attns, all_idx), #return reshaped_idx for active externalism
495
  )
496
 
497
  # Param Initialization, needed for device='meta' fast initialization
 
596
  use_active_externalism: Optional[bool]=None,
597
  topk:int=None
598
  ):
599
+ if self._memories is not None and self.memories is None: #init memories once on first call
600
  self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)
601
 
602
  return_dict = (return_dict
 
700
  prev_end_loc=0
701
  long_range_past_key_values = None
702
  faiss_indexes= None
703
+ for b_idx in range(0, input_ids.size(-1), stride): #generate kv-pairs using stride
704
  end_loc = min(b_idx + max_len, input_ids.size(-1))
 
705
  trg_len = end_loc - prev_end_loc
706
  subseq = input_ids[:, b_idx:end_loc].to(self.device)
707
  with torch.no_grad():
 
731
  if long_range_past_key_values is not None and faiss_indexes is not None:
732
  raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")
733
 
734
+ if cache_type=='faiss': #add one-hot encoding to match layer, head indices
735
  one_hot_encodings = F.one_hot(torch.arange(0, self.config.n_heads*self.config.n_layers))*10
736
  if faiss_indexes is None:
737
  faiss_indexes = (faiss.IndexFlatIP(to_cache[0][0].size(-2)+one_hot_encodings.size(-1)), faiss.IndexFlatIP(to_cache[0][1].size(-1)*2))
 
744
  k= rearrange(k, 'b h d s -> b (h s) d', h=self.config.n_heads)
745
  v= rearrange(v, 'b h s d -> b (h s) d', h=self.config.n_heads)
746
  kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
 
747
  else:
748
  if long_range_past_key_values is None:
749
  long_range_past_key_values = [(k.to(self.memory_device),v.to(self.memory_device)) for k,v in to_cache]
 
755
  )
756
  for ind, kv in enumerate(long_range_past_key_values)
757
  ]
758
+ if long_range_past_key_values is not None: #set a limit on manual memory length
759
+ if long_range_past_key_values[0][0].size(-1) > max_length_cache:
760
  long_range_past_key_values = [
761
  (
762
  kv[0][:, :, :, -max_length_cache:],
 
812
  'sequence_id': sequence_id,
813
  'past_key_values': past_key_values,
814
  'use_cache': kwargs.get('use_cache', True),
815
+ 'use_active_externalism': kwargs.get('use_active_externalism'), #add a few more kwargs for active externalism
816
  'topk': kwargs.get('topk', None),
817
  }
818