leemeng commited on
Commit
fb2a59a
·
1 Parent(s): aaeff91

update for eval

Browse files
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/japanese-stablelm-instruct-alpha-7b",
3
+ "architectures": [
4
+ "JapaneseStableLMAlphaForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "stabilityai/japanese-stablelm-instruct-alpha-7b--configuration_japanese_stablelm_alpha.JapaneseStableLMAlphaConfig",
8
+ "AutoModelForCausalLM": "stabilityai/japanese-stablelm-instruct-alpha-7b--modeling_japanese_stablelm_alpha.JapaneseStableLMAlphaForCausalLM"
9
+ },
10
+ "bos_token_id": 3,
11
+ "classifier_dropout": 0.1,
12
+ "eos_token_id": 3,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 4096,
15
+ "initializer_range": 0.02,
16
+ "layer_norm_eps": 1e-05,
17
+ "max_position_embeddings": 1024,
18
+ "num_attention_heads": 32,
19
+ "num_hidden_layers": 32,
20
+ "rotary_emb_base": 10000,
21
+ "rotary_pct": 0.25,
22
+ "rotary_scale_base": 512,
23
+ "tie_word_embeddings": false,
24
+ "torch_dtype": "float16",
25
+ "transformers_version": "4.30.2",
26
+ "use_bias_in_mlp": false,
27
+ "use_cache": false,
28
+ "use_parallel_residual": true,
29
+ "vocab_size": 65535
30
+ }
configuration_japanese_stablelm_alpha.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ JapaneseStableLMAlpha model configuration"""
16
+
17
+ from transformers import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ STABLE_LM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24
+
25
+
26
+ class JapaneseStableLMAlphaConfig(PretrainedConfig):
27
+ r"""
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+ Args:
32
+ vocab_size (`int`, *optional*, defaults to 65536):
33
+ Vocabulary size of the JapaneseStableLMAlphaModel. Defines the number of different tokens that
34
+ can be represented by the `inputs_ids` passed when calling [`JapaneseStableLMAlphaModel`].
35
+ hidden_size (`int`, *optional*, defaults to 4096):
36
+ Dimension of the decoder layers and the pooler layer.
37
+ num_hidden_layers (`int`, *optional*, defaults to 32):
38
+ Number of hidden layers in the Transformer decoder.
39
+ num_attention_heads (`int`, *optional*, defaults to 32):
40
+ Number of attention heads for each attention layer in the Transformer decoder.
41
+ intermediate_size (`int`, *optional*, defaults to 16384):
42
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer decoder.
43
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
44
+ The non-linear activation function (function or string).
45
+ rotary_pct (`float`, *optional*, defaults to 0.25):
46
+ Percentage of hidden dimensions to allocate to rotary embeddings.
47
+ rotary_emb_base (`int`, *optional*, defaults to 10000)
48
+ Base for computing rotary embeddings frequency.
49
+ rotary_scale_base (`int`, *optional*, defaults to 512)
50
+ Base `scale` for computing XPos rotary embeddings scale.
51
+ classifier_dropout (`float`, *optional*, defaults to 0.1):
52
+ Argument used when doing token classification, used in the model
53
+ [`StableLMForTokenClassification`]. The dropout ratio for the hidden layer.
54
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
55
+ The maximum sequence length that this model might ever be used with.
56
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
57
+ initializer_range (`float`, *optional*, defaults to 1e-5):
58
+ The standard deviation of the truncated_normal_initializer for initializing
59
+ all weight matrices.
60
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
61
+ The epsilon used by the layer normalization layers.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions
64
+ (not used by all models). Only relevant if `config.is_decoder=True`.
65
+ use_parallel_residual (`bool`, *optional*, defaults to `True`):
66
+ Whether to use a "parallel" formulation in each Transformer layer,
67
+ which can provide a slight training speedup at large scales.
68
+ Example:
69
+
70
+ ```python
71
+ >>> from transformers import JapaneseStableLMAlphaConfig, JapaneseStableLMAlphaModel
72
+
73
+ >>> # Initializing a JapaneseStableLMAlpha style configuration
74
+ >>> configuration = JapaneseStableLMAlphaConfig()
75
+
76
+ >>> # Initializing a model (with random weights) from the style configuration
77
+ >>> model = JapaneseStableLMAlphaModel(configuration) # doctest: +SKIP
78
+
79
+ >>> # Accessing the model configuration
80
+ >>> configuration = model.config # doctest: +SKIP
81
+ ```"""
82
+ def __init__(
83
+ self,
84
+ vocab_size=65536,
85
+ hidden_size=4096,
86
+ num_hidden_layers=32,
87
+ num_attention_heads=32,
88
+ hidden_act="silu",
89
+ rotary_pct=0.25,
90
+ rotary_emb_base=10000,
91
+ rotary_scale_base=512,
92
+ classifier_dropout=0.1,
93
+ max_position_embeddings=2048,
94
+ initializer_range=0.02,
95
+ layer_norm_eps=1e-5,
96
+ use_cache=True,
97
+ bos_token_id=3,
98
+ eos_token_id=3,
99
+ tie_word_embeddings=False,
100
+ use_parallel_residual=True,
101
+ use_bias_in_mlp=True,
102
+ **kwargs,
103
+ ):
104
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
105
+ self.vocab_size = vocab_size
106
+ self.max_position_embeddings = max_position_embeddings
107
+ self.hidden_size = hidden_size
108
+ self.num_hidden_layers = num_hidden_layers
109
+ self.num_attention_heads = num_attention_heads
110
+ self.hidden_act = hidden_act
111
+ self.rotary_pct = rotary_pct
112
+ self.rotary_emb_base = rotary_emb_base
113
+ self.rotary_scale_base = rotary_scale_base
114
+ self.classifier_dropout = classifier_dropout
115
+ self.initializer_range = initializer_range
116
+ self.layer_norm_eps = layer_norm_eps
117
+ self.use_cache = use_cache
118
+ self.tie_word_embeddings = tie_word_embeddings
119
+ self.use_parallel_residual = use_parallel_residual
120
+ self.use_bias_in_mlp = use_bias_in_mlp
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 3,
4
+ "eos_token_id": 3,
5
+ "transformers_version": "4.30.2"
6
+ }
modeling_japanese_stablelm_alpha.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch JapaneseStableLMAlpha model. """
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithPast,
24
+ CausalLMOutputWithPast,
25
+ )
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import logging
28
+ from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ class JapaneseStableLMAlphaPreTrainedModel(PreTrainedModel):
35
+ """
36
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
37
+ models.
38
+ """
39
+
40
+ config_class = JapaneseStableLMAlphaConfig
41
+ base_model_prefix = "transformer"
42
+ supports_gradient_checkpointing = True
43
+ _no_split_modules = ["DecoderLayer"]
44
+ _skip_keys_device_placement = "past_key_values"
45
+
46
+ def _init_weights(self, module):
47
+ """Initialize the weights"""
48
+ if isinstance(module, nn.Linear):
49
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
50
+ if module.bias is not None:
51
+ module.bias.data.zero_()
52
+ elif isinstance(module, nn.Embedding):
53
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
54
+ if module.padding_idx is not None:
55
+ module.weight.data[module.padding_idx].zero_()
56
+ elif isinstance(module, nn.LayerNorm):
57
+ if module.bias is not None:
58
+ module.bias.data.zero_()
59
+ if module.weight is not None:
60
+ module.weight.data.fill_(1.0)
61
+
62
+ def _set_gradient_checkpointing(self, module, value=False):
63
+ if isinstance(module, JapaneseStableLMAlphaModel):
64
+ module.gradient_checkpointing = value
65
+
66
+
67
+ class JapaneseStableLMAlphaModel(JapaneseStableLMAlphaPreTrainedModel):
68
+ def __init__(self, config):
69
+ super().__init__(config)
70
+ self.config = config
71
+
72
+ self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
73
+ self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
74
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
75
+
76
+ self.gradient_checkpointing = False
77
+
78
+ # Initialize weights and apply final processing
79
+ self.post_init()
80
+
81
+ def get_input_embeddings(self):
82
+ return self.embed_in
83
+
84
+ def set_input_embeddings(self, value):
85
+ self.embed_in = value
86
+
87
+ def forward(
88
+ self,
89
+ input_ids: Optional[torch.LongTensor] = None,
90
+ attention_mask: Optional[torch.FloatTensor] = None,
91
+ position_ids: Optional[torch.LongTensor] = None,
92
+ head_mask: Optional[torch.FloatTensor] = None,
93
+ inputs_embeds: Optional[torch.FloatTensor] = None,
94
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
95
+ use_cache: Optional[bool] = None,
96
+ output_attentions: Optional[bool] = None,
97
+ output_hidden_states: Optional[bool] = None,
98
+ return_dict: Optional[bool] = None,
99
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
100
+ r"""
101
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
102
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
103
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
104
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
105
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
106
+ use_cache (`bool`, *optional*):
107
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
108
+ `past_key_values`).
109
+ """
110
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
111
+ output_hidden_states = (
112
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
113
+ )
114
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
115
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
116
+
117
+ if input_ids is not None and inputs_embeds is not None:
118
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
119
+ elif input_ids is not None:
120
+ input_shape = input_ids.size()
121
+ elif inputs_embeds is not None:
122
+ input_shape = inputs_embeds.size()[:-1]
123
+ else:
124
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
125
+
126
+ batch_size, seq_length = input_shape
127
+
128
+ if past_key_values is None:
129
+ past_length = 0
130
+ past_key_values = tuple([None] * self.config.num_hidden_layers)
131
+ else:
132
+ past_length = past_key_values[0][0].size(-2)
133
+
134
+ if position_ids is None:
135
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
136
+ position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
137
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
138
+ else:
139
+ position_ids = position_ids.view(-1, seq_length).long()
140
+
141
+ # Attention mask.
142
+ if attention_mask is not None:
143
+ assert batch_size > 0, "batch_size has to be defined and > 0"
144
+ attention_mask = attention_mask.view(batch_size, -1)
145
+ # We create a 3D attention mask from a 2D tensor mask.
146
+ # Sizes are [batch_size, 1, 1, to_seq_length]
147
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
148
+ # this attention mask is more simple than the triangular masking of causal attention
149
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
150
+ attention_mask = attention_mask[:, None, None, :]
151
+
152
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
153
+ # masked positions, this operation will create a tensor which is 0.0 for
154
+ # positions we want to attend and the dtype's smallest value for masked positions.
155
+ # Since we are adding it to the raw scores before the softmax, this is
156
+ # effectively the same as removing these entirely.
157
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
158
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
159
+
160
+ # Prepare head mask if needed
161
+ # 1.0 in head_mask indicate we keep the head
162
+ # attention_probs has shape bsz x n_heads x N x N
163
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
164
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
165
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
166
+
167
+ if inputs_embeds is None:
168
+ inputs_embeds = self.embed_in(input_ids)
169
+
170
+ hidden_states = inputs_embeds
171
+
172
+ if self.gradient_checkpointing and self.training:
173
+ if use_cache:
174
+ logger.warning(
175
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
176
+ )
177
+ use_cache = False
178
+
179
+ presents = () if use_cache else None
180
+ all_attentions = () if output_attentions else None
181
+ all_hidden_states = () if output_hidden_states else None
182
+ for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
183
+ if output_hidden_states:
184
+ all_hidden_states = all_hidden_states + (hidden_states,)
185
+
186
+ if self.gradient_checkpointing and self.training:
187
+
188
+ def create_custom_forward(module):
189
+ def custom_forward(*inputs):
190
+ # None for layer_past
191
+ return module(*inputs, use_cache, None, output_attentions)
192
+
193
+ return custom_forward
194
+
195
+ outputs = torch.utils.checkpoint.checkpoint(
196
+ create_custom_forward(layer),
197
+ hidden_states,
198
+ attention_mask,
199
+ position_ids,
200
+ head_mask[i],
201
+ )
202
+ else:
203
+ outputs = layer(
204
+ hidden_states,
205
+ attention_mask=attention_mask,
206
+ position_ids=position_ids,
207
+ head_mask=head_mask[i],
208
+ layer_past=layer_past,
209
+ use_cache=use_cache,
210
+ output_attentions=output_attentions,
211
+ )
212
+ hidden_states = outputs[0]
213
+ if use_cache is True:
214
+ presents = presents + (outputs[1],)
215
+ if output_attentions:
216
+ all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
217
+
218
+ hidden_states = self.final_layer_norm(hidden_states)
219
+ # Add last hidden state
220
+ if output_hidden_states:
221
+ all_hidden_states = all_hidden_states + (hidden_states,)
222
+
223
+ if not return_dict:
224
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
225
+
226
+ return BaseModelOutputWithPast(
227
+ last_hidden_state=hidden_states,
228
+ past_key_values=presents,
229
+ hidden_states=all_hidden_states,
230
+ attentions=all_attentions,
231
+ )
232
+
233
+
234
+ class DecoderLayer(nn.Module):
235
+ def __init__(self, config):
236
+ super().__init__()
237
+ self.use_parallel_residual = config.use_parallel_residual
238
+ self.input_layernorm = nn.LayerNorm(
239
+ config.hidden_size,
240
+ eps=config.layer_norm_eps,
241
+ elementwise_affine=False,
242
+ )
243
+ self.post_attention_layernorm = nn.LayerNorm(
244
+ config.hidden_size,
245
+ eps=config.layer_norm_eps
246
+ )
247
+ self.attention = Attention(config)
248
+ self.mlp = MLP(config)
249
+
250
+ def forward(
251
+ self,
252
+ hidden_states: Optional[torch.FloatTensor],
253
+ attention_mask: Optional[torch.FloatTensor] = None,
254
+ position_ids: Optional[torch.LongTensor] = None,
255
+ head_mask: Optional[torch.FloatTensor] = None,
256
+ use_cache: Optional[bool] = False,
257
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
258
+ output_attentions: Optional[bool] = False,
259
+ ):
260
+ attention_layer_outputs = self.attention(
261
+ self.input_layernorm(hidden_states),
262
+ attention_mask=attention_mask,
263
+ position_ids=position_ids,
264
+ layer_past=layer_past,
265
+ head_mask=head_mask,
266
+ use_cache=use_cache,
267
+ output_attentions=output_attentions,
268
+ )
269
+ attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
270
+ outputs = attention_layer_outputs[1:]
271
+
272
+ mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
273
+ hidden_states = hidden_states + mlp_output + attn_output
274
+
275
+ if use_cache:
276
+ outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights)
277
+ else:
278
+ outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
279
+
280
+ return outputs
281
+
282
+
283
+ class MLP(nn.Module):
284
+ def __init__(self, config: JapaneseStableLMAlphaConfig):
285
+ super().__init__()
286
+ hidden_size = config.hidden_size
287
+ multiple_of = 256
288
+ ff_dim = int(8 * hidden_size / 3)
289
+ intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
290
+
291
+ self.packed_input_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
292
+ self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
293
+ self.act = nn.SiLU()
294
+
295
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
296
+ ff, ff_gate = self.packed_input_proj(x).chunk(2, dim=-1)
297
+ return self.out_proj(ff * self.act(ff_gate))
298
+
299
+
300
+ class RotaryEmbedding(torch.nn.Module):
301
+ """Based on Tri Dao's XPos: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py"""
302
+ def __init__(
303
+ self,
304
+ dim: int,
305
+ max_position_embeddings: int,
306
+ base: int = 10_000,
307
+ scale_base: int = 512,
308
+ device: str = None
309
+ ):
310
+ super().__init__()
311
+ self.dim = dim
312
+ self.seq_len_cached = max_position_embeddings
313
+
314
+ # Set up `inv_freq` term
315
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
316
+ self.register_buffer("inv_freq", inv_freq)
317
+
318
+ # Set up `scale` term
319
+ self.scale_base = scale_base
320
+ scale = (
321
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
322
+ if scale_base is not None else None
323
+ )
324
+ self.register_buffer("scale", scale)
325
+
326
+ # Seet up `cos..` and `sin...` cache terms
327
+ t = torch.arange(self.seq_len_cached, device=device, dtype=torch.float32)
328
+ freqs = torch.outer(t, self.inv_freq)
329
+ # freqs = torch.cat((freqs, freqs), dim=-1)
330
+ seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
331
+ power = (seq_range - self.seq_len_cached // 2) / self.scale_base
332
+ scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
333
+ # scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
334
+ self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
335
+ self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
336
+ self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
337
+ self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
338
+
339
+ def forward(self, x, seq_len=None):
340
+ if seq_len > self.seq_len_cached:
341
+ self.seq_len_cached = seq_len
342
+ t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
343
+ freqs = torch.outer(t, self.inv_freq)
344
+ freqs = torch.cat((freqs, freqs), dim=-1)
345
+ seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
346
+ power = (seq_range - self.seq_len_cached // 2) / self.scale_base
347
+ scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
348
+ scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
349
+ self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
350
+ self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
351
+ self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
352
+ self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
353
+ return (
354
+ self.cos_cached[:seq_len, ...],
355
+ self.sin_cached[:seq_len, ...],
356
+ self.cos_k_cached[:seq_len, ...],
357
+ self.sin_k_cached[:seq_len, ...],
358
+ )
359
+
360
+
361
+ def rotate_half(x):
362
+ x1, x2 = x.chunk(2, dim=-1)
363
+ return torch.cat((-x2, x1), dim=-1)
364
+
365
+
366
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, cos_k=None, sin_k=None):
367
+ """
368
+ q, k: [bs, num_heads, seq_len, rot_dim]
369
+ cos, sin: [seq_len, rot_dim / 2]
370
+ position_ids: [bs, seq_len]
371
+ """
372
+ # print(f"q: {q.shape}, k: {k.shape}, cos: {cos.shape}, sin: {sin.shape}, position_ids: {position_ids.shape}")
373
+ import einops
374
+ cos = einops.repeat(cos, 's r -> s (2 r)')
375
+ sin = einops.repeat(sin, 's r -> s (2 r)')
376
+ cos_k = einops.repeat(cos_k, 's r -> s (2 r)')
377
+ sin_k = einops.repeat(sin_k, 's r -> s (2 r)')
378
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
379
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
380
+ cos_k = cos_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
381
+ sin_k = sin_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
382
+
383
+ q_embed = (q * cos) + (rotate_half(q) * sin)
384
+ k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
385
+ return q_embed, k_embed
386
+
387
+
388
+ class Attention(nn.Module):
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ self.num_attention_heads = config.num_attention_heads
392
+ self.hidden_size = config.hidden_size
393
+ if self.hidden_size % self.num_attention_heads != 0:
394
+ raise ValueError(
395
+ "The hidden size is not divisble by the number of attention heads! Make sure to update them"
396
+ )
397
+ self.head_size = self.hidden_size // self.num_attention_heads
398
+
399
+ max_positions = config.max_position_embeddings
400
+ self.register_buffer(
401
+ "bias",
402
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
403
+ 1, 1, max_positions, max_positions
404
+ ),
405
+ persistent=False,
406
+ )
407
+ self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
408
+
409
+ self.rotary_ndims = int(self.head_size * config.rotary_pct)
410
+ self.rotary_emb = RotaryEmbedding(
411
+ self.rotary_ndims,
412
+ max_position_embeddings=config.max_position_embeddings,
413
+ base=config.rotary_emb_base,
414
+ scale_base=config.rotary_scale_base,
415
+ )
416
+
417
+ self.register_buffer(
418
+ "norm_factor",
419
+ torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
420
+ persistent=False,
421
+ )
422
+
423
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
424
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
425
+
426
+ def forward(
427
+ self,
428
+ hidden_states: torch.FloatTensor,
429
+ attention_mask: torch.FloatTensor,
430
+ position_ids: torch.LongTensor,
431
+ head_mask: Optional[torch.FloatTensor] = None,
432
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
433
+ use_cache: Optional[bool] = False,
434
+ output_attentions: Optional[bool] = False,
435
+ ):
436
+ has_layer_past = layer_past is not None
437
+
438
+ # Compute QKV
439
+ # Attention heads [batch, seq_len, hidden_size]
440
+ # --> [batch, seq_len, (np * 3 * head_size)]
441
+ qkv = self.query_key_value(hidden_states)
442
+
443
+ # [batch, seq_len, (num_heads * 3 * head_size)]
444
+ # --> [batch, seq_len, num_heads, 3 * head_size]
445
+ new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
446
+ qkv = qkv.view(*new_qkv_shape)
447
+
448
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
449
+ query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
450
+ key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
451
+ value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
452
+
453
+ # Compute rotary embeddings on rotary_ndims
454
+ query_rot = query[..., : self.rotary_ndims]
455
+ query_pass = query[..., self.rotary_ndims :]
456
+ key_rot = key[..., : self.rotary_ndims]
457
+ key_pass = key[..., self.rotary_ndims :]
458
+
459
+ # Compute token offset for rotary embeddings (when decoding)
460
+ kv_seq_len = key.shape[-2]
461
+ if has_layer_past:
462
+ kv_seq_len += layer_past[0].shape[-2]
463
+
464
+ # Add rotary embeddings to query and key
465
+ # TODO: Check if using xpos
466
+ cos, sin, cos_k, sin_k = self.rotary_emb(value, seq_len=kv_seq_len)
467
+ query, key = apply_rotary_pos_emb(
468
+ query_rot, key_rot, cos, sin, position_ids, cos_k=cos_k, sin_k=sin_k)
469
+
470
+ query = torch.cat((query, query_pass), dim=-1)
471
+ key = torch.cat((key, key_pass), dim=-1)
472
+
473
+ # Cache QKV values
474
+ if has_layer_past:
475
+ past_key = layer_past[0]
476
+ past_value = layer_past[1]
477
+ key = torch.cat((past_key, key), dim=-2)
478
+ value = torch.cat((past_value, value), dim=-2)
479
+ present = (key, value) if use_cache else None
480
+
481
+ # Compute attention
482
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
483
+
484
+ # Merge attn_head_size dim and num_attn_heads dim into hidden dim
485
+ # [bs, seq_len, num_attention_heads, attn_head_size]
486
+ attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
487
+ attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), self.num_attention_heads * self.head_size)
488
+
489
+ attn_output = self.dense(attn_output)
490
+
491
+ outputs = (attn_output, present)
492
+ if output_attentions:
493
+ outputs += (attn_weights,)
494
+
495
+ return outputs
496
+
497
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
498
+ # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
499
+ # compute causal mask from causal mask buffer
500
+
501
+ batch_size, num_attention_heads, query_length, attn_head_size = query.size()
502
+ key_length = key.size(-2)
503
+
504
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
505
+
506
+ query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
507
+ key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
508
+ attn_scores = torch.zeros(
509
+ batch_size * num_attention_heads,
510
+ query_length,
511
+ key_length,
512
+ dtype=query.dtype,
513
+ device=key.device,
514
+ )
515
+ attn_scores = torch.baddbmm(
516
+ attn_scores,
517
+ query,
518
+ key.transpose(1, 2),
519
+ beta=1.0,
520
+ alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
521
+ )
522
+ attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
523
+
524
+ mask_value = torch.finfo(attn_scores.dtype).min
525
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
526
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
527
+ mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype, device=attn_scores.device)
528
+ attn_scores = torch.where(causal_mask, attn_scores, mask_value)
529
+
530
+ if attention_mask is not None:
531
+ # Apply the attention mask
532
+ attn_scores = attn_scores + attention_mask
533
+
534
+ # NOTE: Upcast to float32
535
+ attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).type_as(value)
536
+
537
+ # Mask heads if we want to
538
+ if head_mask is not None:
539
+ attn_weights = attn_weights * head_mask
540
+
541
+ attn_output = torch.matmul(attn_weights, value)
542
+ return attn_output, attn_weights
543
+
544
+
545
+ def attention_mask_func(attention_scores, ltor_mask):
546
+ attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
547
+ return attention_scores
548
+
549
+
550
+ class JapaneseStableLMAlphaForCausalLM(JapaneseStableLMAlphaPreTrainedModel):
551
+ _tied_weights_keys = ["embed_out.weight"]
552
+
553
+ def __init__(self, config):
554
+ super().__init__(config)
555
+
556
+ self.transformer = JapaneseStableLMAlphaModel(config)
557
+ self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
558
+
559
+ # Initialize weights and apply final processing
560
+ self.post_init()
561
+
562
+ def get_output_embeddings(self):
563
+ return self.embed_out
564
+
565
+ def set_output_embeddings(self, new_embeddings):
566
+ self.embed_out = new_embeddings
567
+
568
+ def forward(
569
+ self,
570
+ input_ids: Optional[torch.LongTensor] = None,
571
+ attention_mask: Optional[torch.FloatTensor] = None,
572
+ position_ids: Optional[torch.LongTensor] = None,
573
+ inputs_embeds: Optional[torch.FloatTensor] = None,
574
+ head_mask: Optional[torch.FloatTensor] = None,
575
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
576
+ labels: Optional[torch.LongTensor] = None,
577
+ use_cache: Optional[bool] = None,
578
+ output_attentions: Optional[bool] = None,
579
+ output_hidden_states: Optional[bool] = None,
580
+ return_dict: Optional[bool] = None,
581
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
582
+ r"""
583
+ Example:
584
+
585
+ ```python
586
+ >>> import torch
587
+ >>> from transformers import LlamaTokenizer, JapaneseStableLMAlphaForCausalLM, JapaneseStableLMAlphaConfig
588
+
589
+ >>> tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
590
+ >>> config = JapaneseStableLMAlphaConfig.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b")
591
+ >>> config.is_decoder = True
592
+ >>> model = JapaneseStableLMAlphaForCausalLM.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b", config=config, trust_remote_code=True)
593
+
594
+ >>> inputs = tokenizer("日本語の美しいところは、", return_tensors="pt")
595
+ >>> outputs = model(**inputs)
596
+
597
+ >>> prediction_logits = outputs.logits
598
+ ```"""
599
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
600
+
601
+ outputs = self.transformer(
602
+ input_ids,
603
+ attention_mask=attention_mask,
604
+ position_ids=position_ids,
605
+ head_mask=head_mask,
606
+ inputs_embeds=inputs_embeds,
607
+ past_key_values=past_key_values,
608
+ use_cache=use_cache,
609
+ output_attentions=output_attentions,
610
+ output_hidden_states=output_hidden_states,
611
+ return_dict=return_dict,
612
+ )
613
+
614
+ hidden_states = outputs[0]
615
+ lm_logits = self.embed_out(hidden_states)
616
+
617
+ lm_loss = None
618
+ if labels is not None:
619
+ # move labels to correct device to enable model parallelism
620
+ labels = labels.to(lm_logits.device)
621
+ # we are doing next-token prediction; shift prediction scores and input ids by one
622
+ shift_logits = lm_logits[:, :-1, :].contiguous()
623
+ labels = labels[:, 1:].contiguous()
624
+ loss_fct = CrossEntropyLoss()
625
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
626
+
627
+ if not return_dict:
628
+ output = (lm_logits,) + outputs[1:]
629
+ return ((lm_loss,) + output) if lm_loss is not None else output
630
+
631
+ return CausalLMOutputWithPast(
632
+ loss=lm_loss,
633
+ logits=lm_logits,
634
+ past_key_values=outputs.past_key_values,
635
+ hidden_states=outputs.hidden_states,
636
+ attentions=outputs.attentions,
637
+ )
638
+
639
+ def prepare_inputs_for_generation(
640
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
641
+ ):
642
+ input_shape = input_ids.shape
643
+
644
+ # cut decoder_input_ids if past is used
645
+ if past_key_values and past_key_values[0] is not None:
646
+ input_ids = input_ids[:, -1:]
647
+
648
+ position_ids = kwargs.get("position_ids", None)
649
+ if attention_mask is not None and position_ids is None:
650
+ # create position_ids on the fly for batch generation
651
+ position_ids = attention_mask.long().cumsum(-1) - 1
652
+ position_ids.masked_fill_(attention_mask == 0, 1)
653
+ if past_key_values:
654
+ position_ids = position_ids[:, -1].unsqueeze(-1)
655
+
656
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
657
+ if attention_mask is None:
658
+ attention_mask = input_ids.new_ones(input_shape)
659
+
660
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
661
+ if inputs_embeds is not None and past_key_values is None:
662
+ model_inputs = {"inputs_embeds": inputs_embeds}
663
+ else:
664
+ model_inputs = {"input_ids": input_ids}
665
+
666
+ model_inputs.update(
667
+ {
668
+ "attention_mask": attention_mask,
669
+ "past_key_values": past_key_values,
670
+ "position_ids": position_ids,
671
+ }
672
+ )
673
+
674
+ return model_inputs
675
+
676
+ def _reorder_cache(self, past_key_values, beam_idx):
677
+ reordered_past = ()
678
+ for layer_past in past_key_values:
679
+ reordered_past += (
680
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
681
+ )
682
+ return reordered_past
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:431f0de2749eac9f61b583f53295352192b26b861d0194832164aa9825ad8d10
3
+ size 14026364945