Text Generation
Transformers
PyTorch
mpt
Composer
MosaicML
llm-foundry
custom_code
text-generation-inference
sam-mosaic danbider commited on
Commit
a5e85ae
1 Parent(s): 481f74e

LLM-foundry update June 27, 2023 21:20:15 (#50)

Browse files

- LLM-foundry update June 27, 2023 21:20:15 (ea08ffa68efe8d2a3364e4442fa932dff45eea90)


Co-authored-by: Dan Biderman <[email protected]>

Files changed (2) hide show
  1. modeling_mpt.py +7 -2
  2. norm.py +1 -1
modeling_mpt.py CHANGED
@@ -140,7 +140,7 @@ class MPTModel(MPTPreTrainedModel):
140
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
141
  return attn_bias
142
 
143
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
144
  return_dict = return_dict if return_dict is not None else self.config.return_dict
145
  use_cache = use_cache if use_cache is not None else self.config.use_cache
146
  if attention_mask is not None:
@@ -156,6 +156,8 @@ class MPTModel(MPTPreTrainedModel):
156
  raise NotImplementedError('MPT does not support training with left padding.')
157
  if self.prefix_lm and prefix_mask is None:
158
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
 
 
159
  if self.training:
160
  if self.attn_uses_sequence_id and sequence_id is None:
161
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
@@ -225,6 +227,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
225
  super().__init__(config)
226
  if not config.tie_word_embeddings:
227
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
 
228
  self.transformer = MPTModel(config)
229
  for child in self.transformer.children():
230
  if isinstance(child, torch.nn.ModuleList):
@@ -259,9 +262,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
259
  def get_decoder(self):
260
  return self.transformer
261
 
262
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
263
  return_dict = return_dict if return_dict is not None else self.config.return_dict
264
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
265
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
266
  logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
267
  if self.logit_scale is not None:
 
140
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
141
  return attn_bias
142
 
143
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
144
  return_dict = return_dict if return_dict is not None else self.config.return_dict
145
  use_cache = use_cache if use_cache is not None else self.config.use_cache
146
  if attention_mask is not None:
 
156
  raise NotImplementedError('MPT does not support training with left padding.')
157
  if self.prefix_lm and prefix_mask is None:
158
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
159
+ if inputs_embeds is not None:
160
+ raise NotImplementedError('inputs_embeds is not implemented for MPT.')
161
  if self.training:
162
  if self.attn_uses_sequence_id and sequence_id is None:
163
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
 
227
  super().__init__(config)
228
  if not config.tie_word_embeddings:
229
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
230
+ print(f'Instantiating an MPTForCausalLM model from {__file__}')
231
  self.transformer = MPTModel(config)
232
  for child in self.transformer.children():
233
  if isinstance(child, torch.nn.ModuleList):
 
262
  def get_decoder(self):
263
  return self.transformer
264
 
265
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
266
  return_dict = return_dict if return_dict is not None else self.config.return_dict
267
  use_cache = use_cache if use_cache is not None else self.config.use_cache
268
+ if inputs_embeds is not None:
269
+ raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
270
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
271
  logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
272
  if self.logit_scale is not None:
norm.py CHANGED
@@ -25,7 +25,7 @@ class LPLayerNorm(torch.nn.LayerNorm):
25
  return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26
 
27
  def rms_norm(x, weight=None, eps=1e-05):
28
- output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29
  if weight is not None:
30
  return output * weight
31
  return output
 
25
  return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26
 
27
  def rms_norm(x, weight=None, eps=1e-05):
28
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29
  if weight is not None:
30
  return output * weight
31
  return output