sam-mosaic
commited on
Commit
·
e913229
1
Parent(s):
996ffc5
Upload folder using huggingface_hub
Browse files- attention.py +1 -1
- config.json +3 -3
- generation_config.json +1 -1
- modeling_mpt.py +13 -8
- norm.py +3 -2
- pytorch_model-00001-of-00002.bin +1 -1
- pytorch_model-00002-of-00002.bin +1 -1
attention.py
CHANGED
@@ -46,7 +46,7 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
|
|
46 |
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
47 |
if is_causal and (not q.size(2) == 1):
|
48 |
s = max(s_q, s_k)
|
49 |
-
causal_mask = attn_weight.new_ones(s, s, dtype=torch.
|
50 |
causal_mask = causal_mask.tril()
|
51 |
causal_mask = causal_mask.to(torch.bool)
|
52 |
causal_mask = ~causal_mask
|
|
|
46 |
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
47 |
if is_causal and (not q.size(2) == 1):
|
48 |
s = max(s_q, s_k)
|
49 |
+
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
|
50 |
causal_mask = causal_mask.tril()
|
51 |
causal_mask = causal_mask.to(torch.bool)
|
52 |
causal_mask = ~causal_mask
|
config.json
CHANGED
@@ -27,9 +27,9 @@
|
|
27 |
"emb_init_uniform_lim": null,
|
28 |
"fan_mode": "fan_in",
|
29 |
"init_div_is_residual": true,
|
30 |
-
"init_gain": 0,
|
31 |
"init_nonlinearity": "relu",
|
32 |
-
"init_std":
|
33 |
"name": "kaiming_normal_",
|
34 |
"verbose": 0
|
35 |
},
|
@@ -45,7 +45,7 @@
|
|
45 |
"resid_pdrop": 0,
|
46 |
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
47 |
"torch_dtype": "bfloat16",
|
48 |
-
"transformers_version": "4.
|
49 |
"use_cache": false,
|
50 |
"verbose": 0,
|
51 |
"vocab_size": 50432
|
|
|
27 |
"emb_init_uniform_lim": null,
|
28 |
"fan_mode": "fan_in",
|
29 |
"init_div_is_residual": true,
|
30 |
+
"init_gain": 0.0,
|
31 |
"init_nonlinearity": "relu",
|
32 |
+
"init_std": null,
|
33 |
"name": "kaiming_normal_",
|
34 |
"verbose": 0
|
35 |
},
|
|
|
45 |
"resid_pdrop": 0,
|
46 |
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
47 |
"torch_dtype": "bfloat16",
|
48 |
+
"transformers_version": "4.30.2",
|
49 |
"use_cache": false,
|
50 |
"verbose": 0,
|
51 |
"vocab_size": 50432
|
generation_config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
"_from_model_config": true,
|
3 |
-
"transformers_version": "4.
|
4 |
"use_cache": false
|
5 |
}
|
|
|
1 |
{
|
2 |
"_from_model_config": true,
|
3 |
+
"transformers_version": "4.30.2",
|
4 |
"use_cache": false
|
5 |
}
|
modeling_mpt.py
CHANGED
@@ -18,7 +18,7 @@ from .configuration_mpt import MPTConfig
|
|
18 |
from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
|
19 |
from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
|
20 |
from .meta_init_context import init_empty_weights
|
21 |
-
from .param_init_fns import
|
22 |
try:
|
23 |
from .flash_attn_triton import flash_attn_func
|
24 |
except:
|
@@ -80,7 +80,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
80 |
def get_input_embeddings(self):
|
81 |
return self.wte
|
82 |
|
83 |
-
def set_input_embeddings(self, value):
|
84 |
self.wte = value
|
85 |
|
86 |
@torch.no_grad()
|
@@ -140,7 +140,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
140 |
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
141 |
return attn_bias
|
142 |
|
143 |
-
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
|
144 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
145 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
146 |
if attention_mask is not None:
|
@@ -156,6 +156,8 @@ class MPTModel(MPTPreTrainedModel):
|
|
156 |
raise NotImplementedError('MPT does not support training with left padding.')
|
157 |
if self.prefix_lm and prefix_mask is None:
|
158 |
raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
|
|
|
|
|
159 |
if self.training:
|
160 |
if self.attn_uses_sequence_id and sequence_id is None:
|
161 |
raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
|
@@ -225,7 +227,8 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
225 |
super().__init__(config)
|
226 |
if not config.tie_word_embeddings:
|
227 |
raise ValueError('MPTForCausalLM only supports tied word embeddings')
|
228 |
-
|
|
|
229 |
for child in self.transformer.children():
|
230 |
if isinstance(child, torch.nn.ModuleList):
|
231 |
continue
|
@@ -259,9 +262,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
259 |
def get_decoder(self):
|
260 |
return self.transformer
|
261 |
|
262 |
-
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
|
263 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
264 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
|
|
265 |
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
|
266 |
logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
|
267 |
if self.logit_scale is not None:
|
@@ -270,9 +275,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
270 |
logits *= self.logit_scale
|
271 |
loss = None
|
272 |
if labels is not None:
|
273 |
-
|
274 |
-
|
275 |
-
loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
|
276 |
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
277 |
|
278 |
def param_init_fn(self, module):
|
|
|
18 |
from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
|
19 |
from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
|
20 |
from .meta_init_context import init_empty_weights
|
21 |
+
from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
|
22 |
try:
|
23 |
from .flash_attn_triton import flash_attn_func
|
24 |
except:
|
|
|
80 |
def get_input_embeddings(self):
|
81 |
return self.wte
|
82 |
|
83 |
+
def set_input_embeddings(self, value: nn.Embedding):
|
84 |
self.wte = value
|
85 |
|
86 |
@torch.no_grad()
|
|
|
140 |
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
141 |
return attn_bias
|
142 |
|
143 |
+
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
|
144 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
145 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
146 |
if attention_mask is not None:
|
|
|
156 |
raise NotImplementedError('MPT does not support training with left padding.')
|
157 |
if self.prefix_lm and prefix_mask is None:
|
158 |
raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
|
159 |
+
if inputs_embeds is not None:
|
160 |
+
raise NotImplementedError('inputs_embeds is not implemented for MPT.')
|
161 |
if self.training:
|
162 |
if self.attn_uses_sequence_id and sequence_id is None:
|
163 |
raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
|
|
|
227 |
super().__init__(config)
|
228 |
if not config.tie_word_embeddings:
|
229 |
raise ValueError('MPTForCausalLM only supports tied word embeddings')
|
230 |
+
print(f'Instantiating an MPTForCausalLM model from {__file__}')
|
231 |
+
self.transformer: MPTModel = MPTModel(config)
|
232 |
for child in self.transformer.children():
|
233 |
if isinstance(child, torch.nn.ModuleList):
|
234 |
continue
|
|
|
262 |
def get_decoder(self):
|
263 |
return self.transformer
|
264 |
|
265 |
+
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
|
266 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
267 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
268 |
+
if inputs_embeds is not None:
|
269 |
+
raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
|
270 |
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
|
271 |
logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
|
272 |
if self.logit_scale is not None:
|
|
|
275 |
logits *= self.logit_scale
|
276 |
loss = None
|
277 |
if labels is not None:
|
278 |
+
_labels = torch.roll(labels, shifts=-1)
|
279 |
+
_labels[:, -1] = -100
|
280 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
|
281 |
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
282 |
|
283 |
def param_init_fn(self, module):
|
norm.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import torch
|
2 |
|
3 |
def _cast_if_autocast_enabled(tensor):
|
@@ -25,7 +26,7 @@ class LPLayerNorm(torch.nn.LayerNorm):
|
|
25 |
return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
|
26 |
|
27 |
def rms_norm(x, weight=None, eps=1e-05):
|
28 |
-
output = x
|
29 |
if weight is not None:
|
30 |
return output * weight
|
31 |
return output
|
@@ -53,4 +54,4 @@ class LPRMSNorm(RMSNorm):
|
|
53 |
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
54 |
with torch.autocast(enabled=False, device_type=x.device.type):
|
55 |
return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
|
56 |
-
NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
|
|
|
1 |
+
from typing import Dict, Type
|
2 |
import torch
|
3 |
|
4 |
def _cast_if_autocast_enabled(tensor):
|
|
|
26 |
return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
|
27 |
|
28 |
def rms_norm(x, weight=None, eps=1e-05):
|
29 |
+
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
30 |
if weight is not None:
|
31 |
return output * weight
|
32 |
return output
|
|
|
54 |
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
55 |
with torch.autocast(enabled=False, device_type=x.device.type):
|
56 |
return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
|
57 |
+
NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
|
pytorch_model-00001-of-00002.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 9943040275
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6003cd1c33b5a661320c11225b54fb0cdfd931f73241ed810c57dc9e32163146
|
3 |
size 9943040275
|
pytorch_model-00002-of-00002.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 3355599187
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:234b5d739ed88a00dcf1e28932158157418d386837d2345f0ec8a0b218e7d823
|
3 |
size 3355599187
|