jeffra commited on
Commit
aa6eb8c
·
verified ·
1 Parent(s): 67c9921

Upload configuration_arctic.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_arctic.py +216 -0
configuration_arctic.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Snowflake AI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ Arctic model configuration"""
15
+
16
+ from dataclasses import asdict, dataclass
17
+ from typing import Any, Dict
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
+ "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json",
27
+ }
28
+
29
+
30
+ @dataclass
31
+ class ArcticLoraConfig:
32
+ lora_r: int = 64
33
+ lora_alpha: float = 16
34
+ shard_base_weights: bool = False
35
+
36
+
37
+ @dataclass
38
+ class ArcticQuantizationConfig:
39
+ q_bits: int = 8
40
+ rounding: str = "nearest"
41
+ mantissa_bits: int = 3
42
+ group_size: int = 512
43
+
44
+
45
+ class ArcticConfig(PretrainedConfig):
46
+ r"""
47
+ This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an
48
+ Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration
49
+ with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config..
50
+
51
+
52
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
53
+ documentation from [`PretrainedConfig`] for more information.
54
+
55
+
56
+ Args:
57
+ vocab_size (`int`, *optional*, defaults to 32000):
58
+ Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the
59
+ `inputs_ids` passed when calling [`ArcticModel`]
60
+ hidden_size (`int`, *optional*, defaults to 4096):
61
+ Dimension of the hidden representations.
62
+ intermediate_size (`int`, *optional*, defaults to 14336):
63
+ Dimension of the MLP representations.
64
+ num_hidden_layers (`int`, *optional*, defaults to 32):
65
+ Number of hidden layers in the Transformer encoder.
66
+ num_attention_heads (`int`, *optional*, defaults to 32):
67
+ Number of attention heads for each attention layer in the Transformer encoder.
68
+ num_key_value_heads (`int`, *optional*, defaults to 8):
69
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
70
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
71
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
72
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
73
+ by meanpooling all the original heads within that group. For more details checkout [this
74
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
75
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
76
+ The non-linear activation function (function or string) in the decoder.
77
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
78
+ The maximum sequence length that this model might ever be used with. Arctic's sliding window attention
79
+ allows sequence of up to 4096*32 tokens.
80
+ initializer_range (`float`, *optional*, defaults to 0.02):
81
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
82
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
83
+ The epsilon used by the rms normalization layers.
84
+ use_cache (`bool`, *optional*, defaults to `True`):
85
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
86
+ relevant if `config.is_decoder=True`.
87
+ pad_token_id (`int`, *optional*):
88
+ The id of the padding token.
89
+ bos_token_id (`int`, *optional*, defaults to 1):
90
+ The id of the "beginning-of-sequence" token.
91
+ eos_token_id (`int`, *optional*, defaults to 2):
92
+ The id of the "end-of-sequence" token.
93
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
94
+ Whether the model's input and output word embeddings should be tied.
95
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
96
+ The base period of the RoPE embeddings.
97
+ sliding_window (`int`, *optional*):
98
+ Sliding window attention window size. If not specified, will default to `4096`.
99
+ attention_dropout (`float`, *optional*, defaults to 0.0):
100
+ The dropout ratio for the attention probabilities.
101
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
102
+ The number of experts to root per-token, can be also interpreted as the `top-p` routing
103
+ parameter
104
+ num_local_experts (`int`, *optional*, defaults to 8):
105
+ Number of experts per Sparse MLP layer.
106
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
107
+ The aux loss factor for the total loss.
108
+
109
+ ```python
110
+ >>> from transformers import ArcticModel, ArcticConfig
111
+
112
+ >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to.
113
+ >>> configuration = ArcticConfig()
114
+
115
+ >>> # Initializing a model from the Arctic 7B style configuration
116
+ >>> model = ArcticModel(configuration)
117
+
118
+ >>> # Accessing the model configuration
119
+ >>> configuration = model.config
120
+ ```"""
121
+
122
+ model_type = "arctic"
123
+ keys_to_ignore_at_inference = ["past_key_values"]
124
+
125
+ def __init__(
126
+ self,
127
+ vocab_size=32000,
128
+ hidden_size=4096,
129
+ intermediate_size=14336,
130
+ num_hidden_layers=32,
131
+ num_attention_heads=32,
132
+ num_key_value_heads=None,
133
+ hidden_act="silu",
134
+ max_position_embeddings=4096,
135
+ initializer_range=0.02,
136
+ rms_norm_eps=1e-5,
137
+ use_cache=True,
138
+ pad_token_id=None,
139
+ bos_token_id=1,
140
+ eos_token_id=2,
141
+ tie_word_embeddings=False,
142
+ rope_theta=1e6,
143
+ sliding_window=None,
144
+ attention_dropout=0.0,
145
+ num_experts_per_tok=1,
146
+ num_local_experts=8,
147
+ router_aux_loss_coef=0.001,
148
+ moe_layer_frequency=2,
149
+ parallel_attn_mlp_res=False,
150
+ moe_train_capacity_factor=1,
151
+ moe_eval_capacity_factor=1,
152
+ enable_expert_tensor_parallelism=False,
153
+ moe_min_capacity=0,
154
+ moe_token_dropping=True,
155
+ quantization=None,
156
+ **kwargs,
157
+ ):
158
+ self.vocab_size = vocab_size
159
+ self.max_position_embeddings = max_position_embeddings
160
+ self.hidden_size = hidden_size
161
+ self.intermediate_size = intermediate_size
162
+ self.num_hidden_layers = num_hidden_layers
163
+ self.num_attention_heads = num_attention_heads
164
+ self.sliding_window = sliding_window
165
+
166
+ # for backward compatibility
167
+ if num_key_value_heads is None:
168
+ num_key_value_heads = num_attention_heads
169
+
170
+ self.num_key_value_heads = num_key_value_heads
171
+ self.hidden_act = hidden_act
172
+ self.initializer_range = initializer_range
173
+ self.rms_norm_eps = rms_norm_eps
174
+ self.use_cache = use_cache
175
+ self.rope_theta = rope_theta
176
+ self.attention_dropout = attention_dropout
177
+
178
+ self.num_experts_per_tok = num_experts_per_tok
179
+ self.num_local_experts = num_local_experts
180
+ self.router_aux_loss_coef = router_aux_loss_coef
181
+ self.moe_layer_frequency = moe_layer_frequency
182
+ self.moe_train_capacity_factor = moe_train_capacity_factor
183
+ self.moe_eval_capacity_factor = moe_eval_capacity_factor
184
+ self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism
185
+ self.moe_min_capacity = moe_min_capacity
186
+ self.moe_token_dropping = moe_token_dropping
187
+ self.parallel_attn_mlp_res = parallel_attn_mlp_res
188
+ if isinstance(quantization, dict):
189
+ self.quantization = ArcticQuantizationConfig(**quantization)
190
+ else:
191
+ self.quantization = quantization
192
+
193
+ super().__init__(
194
+ pad_token_id=pad_token_id,
195
+ bos_token_id=bos_token_id,
196
+ eos_token_id=eos_token_id,
197
+ tie_word_embeddings=tie_word_embeddings,
198
+ **kwargs,
199
+ )
200
+
201
+ @classmethod
202
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig":
203
+ result = super().from_dict(config_dict, **kwargs)
204
+ if isinstance(result, tuple):
205
+ config = result[0]
206
+ else:
207
+ config = result
208
+ if isinstance(config.quantization, dict):
209
+ config.quantization = ArcticQuantizationConfig(**config.quantization)
210
+ return result
211
+
212
+ def to_dict(self) -> Dict[str, Any]:
213
+ ret = super().to_dict()
214
+ if isinstance(ret["quantization"], ArcticQuantizationConfig):
215
+ ret["quantization"] = asdict(ret["quantization"])
216
+ return ret