phoebeklett commited on
Commit
73e59c4
1 Parent(s): 78df6d3

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration.py +5 -4
  2. modeling.py +4 -3
configuration.py CHANGED
@@ -101,6 +101,7 @@ class ExtendedMptAttentionConfig(PretrainedConfig):
101
  sim_threshold=0.25,
102
  tokenizer_all_special_ids=[0, 50278],
103
  remove_special_ids=False,
 
104
  **kwargs,
105
  ):
106
  super().__init__(**kwargs)
@@ -121,6 +122,7 @@ class ExtendedMptAttentionConfig(PretrainedConfig):
121
  self.sim_threshold = sim_threshold
122
  self.tokenizer_all_special_ids = tokenizer_all_special_ids
123
  self.remove_special_ids = remove_special_ids
 
124
 
125
  if attn_type not in ["multihead_attention", "multiquery_attention"]:
126
  raise ValueError(
@@ -245,7 +247,6 @@ class ExtendedMptConfig(PretrainedConfig):
245
  n_layers: int = 32,
246
  expansion_ratio: int = 4,
247
  max_seq_len_inference: int = 2048,
248
- max_seq_len_train: int = 2048,
249
  vocab_size: int = 50432,
250
  resid_pdrop: float = 0.0,
251
  layer_norm_epsilon: float = 1e-5,
@@ -261,11 +262,12 @@ class ExtendedMptConfig(PretrainedConfig):
261
  use_cache: bool = False,
262
  initializer_range=0.02,
263
  use_external_mind: bool = True,
264
- use_external_mind_by_layer: list[bool] = [True for _ in range(32)],
265
  **kwargs,
266
  ):
267
  if attn_config is None:
268
- self.attn_config = ExtendedMptAttentionConfig()
 
 
269
  elif not isinstance(attn_config, ExtendedMptAttentionConfig):
270
  self.attn_config = ExtendedMptAttentionConfig(**attn_config)
271
  else:
@@ -275,7 +277,6 @@ class ExtendedMptConfig(PretrainedConfig):
275
  self.n_layers = n_layers
276
  self.expansion_ratio = expansion_ratio
277
  self.max_seq_len = max_seq_len_inference
278
- self.max_seq_len_train = max_seq_len_train
279
  self.vocab_size = vocab_size
280
  self.resid_pdrop = resid_pdrop
281
  self.emb_pdrop = emb_pdrop
 
101
  sim_threshold=0.25,
102
  tokenizer_all_special_ids=[0, 50278],
103
  remove_special_ids=False,
104
+ use_external_mind_by_layer: list[bool] = [True for _ in range(32)],
105
  **kwargs,
106
  ):
107
  super().__init__(**kwargs)
 
122
  self.sim_threshold = sim_threshold
123
  self.tokenizer_all_special_ids = tokenizer_all_special_ids
124
  self.remove_special_ids = remove_special_ids
125
+ self.use_external_mind_by_layer = use_external_mind_by_layer
126
 
127
  if attn_type not in ["multihead_attention", "multiquery_attention"]:
128
  raise ValueError(
 
247
  n_layers: int = 32,
248
  expansion_ratio: int = 4,
249
  max_seq_len_inference: int = 2048,
 
250
  vocab_size: int = 50432,
251
  resid_pdrop: float = 0.0,
252
  layer_norm_epsilon: float = 1e-5,
 
262
  use_cache: bool = False,
263
  initializer_range=0.02,
264
  use_external_mind: bool = True,
 
265
  **kwargs,
266
  ):
267
  if attn_config is None:
268
+ self.attn_config = ExtendedMptAttentionConfig(
269
+ use_external_mind_by_layer=[True for _ in range(n_layers)]
270
+ )
271
  elif not isinstance(attn_config, ExtendedMptAttentionConfig):
272
  self.attn_config = ExtendedMptAttentionConfig(**attn_config)
273
  else:
 
277
  self.n_layers = n_layers
278
  self.expansion_ratio = expansion_ratio
279
  self.max_seq_len = max_seq_len_inference
 
280
  self.vocab_size = vocab_size
281
  self.resid_pdrop = resid_pdrop
282
  self.emb_pdrop = emb_pdrop
modeling.py CHANGED
@@ -42,7 +42,7 @@ from transformers.modeling_outputs import (
42
  from transformers.modeling_utils import PreTrainedModel
43
  from transformers.utils import logging
44
 
45
- from .configuration import ExtendedMptConfig
46
 
47
  logger = logging.get_logger(__name__)
48
 
@@ -920,7 +920,7 @@ class ExtendedMptForCausalLM(MptPreTrainedModel):
920
 
921
  _tied_weights_keys = ["lm_head.weight"]
922
 
923
- def __init__(self, config: ExtendedMptConfig, external_memories=None):
924
  super().__init__(config)
925
  self.transformer: ExtendedMptModel = ExtendedMptModel(config)
926
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -1016,8 +1016,9 @@ class ExtendedMptForCausalLM(MptPreTrainedModel):
1016
  if (
1017
  self.memory_ids is not None and self.memories is None
1018
  ):
 
1019
  self.memories = self.generate_cache(
1020
- self.memory_ids, cache_type=self.memory_type
1021
  )
1022
  # EM: Remove special tokens from memory cache
1023
  if self.remove_special_ids:
 
42
  from transformers.modeling_utils import PreTrainedModel
43
  from transformers.utils import logging
44
 
45
+ from emts_clean.src.mpt.configuration import ExtendedMptConfig
46
 
47
  logger = logging.get_logger(__name__)
48
 
 
920
 
921
  _tied_weights_keys = ["lm_head.weight"]
922
 
923
+ def __init__(self, config: ExtendedMptConfig, external_memories:list=None):
924
  super().__init__(config)
925
  self.transformer: ExtendedMptModel = ExtendedMptModel(config)
926
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
1016
  if (
1017
  self.memory_ids is not None and self.memories is None
1018
  ):
1019
+ self.memory_ids = torch.tensor([self.memory_ids], device=self.device) if type(self.memory_ids)==list else self.memory_ids
1020
  self.memories = self.generate_cache(
1021
+ self.memory_ids, cache_type=self.memory_type,
1022
  )
1023
  # EM: Remove special tokens from memory cache
1024
  if self.remove_special_ids: