Text Generation
Transformers
PyTorch
Safetensors
English
gpt_refact
code
custom_code
Eval Results
svakhreev committed on
Commit
94a3f6d
·
1 Parent(s): 38cebfc

Upload GPTRefactForCausalLM

Browse files
config.json CHANGED
@@ -2,15 +2,13 @@
2
  "architectures": [
3
  "GPTRefactForCausalLM"
4
  ],
5
- "attention_softmax_in_fp32": false,
6
- "attn_pdrop": 0.1,
7
  "auto_map": {
8
  "AutoConfig": "configuration_gpt_refact.GPTRefactConfig",
9
  "AutoModelForCausalLM": "modeling_gpt_refact.GPTRefactForCausalLM"
10
  },
11
- "bos_token_id": -1,
12
  "do_sample": true,
13
- "embd_pdrop": 0.1,
14
  "eos_token_id": 0,
15
  "initializer_range": 0.02,
16
  "layer_norm_epsilon": 1e-05,
@@ -21,10 +19,8 @@
21
  "n_inner": null,
22
  "n_layer": 32,
23
  "n_positions": 4096,
24
- "resid_pdrop": 0.1,
25
- "scale_attention_softmax_in_fp32": false,
26
- "scale_attn_weights": true,
27
- "torch_dtype": "float16",
28
  "transformers_version": "4.31.0",
29
  "use_cache": true,
30
  "vocab_size": 49216
 
2
  "architectures": [
3
  "GPTRefactForCausalLM"
4
  ],
5
+ "attention_bias_in_fp32": true,
6
+ "attention_softmax_in_fp32": true,
7
  "auto_map": {
8
  "AutoConfig": "configuration_gpt_refact.GPTRefactConfig",
9
  "AutoModelForCausalLM": "modeling_gpt_refact.GPTRefactForCausalLM"
10
  },
 
11
  "do_sample": true,
 
12
  "eos_token_id": 0,
13
  "initializer_range": 0.02,
14
  "layer_norm_epsilon": 1e-05,
 
19
  "n_inner": null,
20
  "n_layer": 32,
21
  "n_positions": 4096,
22
+ "scale_attention_softmax_in_fp32": true,
23
+ "torch_dtype": "bfloat16",
 
 
24
  "transformers_version": "4.31.0",
25
  "use_cache": true,
26
  "vocab_size": 49216
configuration_gpt_refact.py CHANGED
@@ -1,7 +1,6 @@
1
  from transformers.configuration_utils import PretrainedConfig
2
  from transformers.utils import logging
3
 
4
-
5
  logger = logging.get_logger(__name__)
6
 
7
 
@@ -16,26 +15,23 @@ class GPTRefactConfig(PretrainedConfig):
16
  }
17
 
18
  def __init__(
19
- self,
20
- vocab_size: int = 49216,
21
- n_positions: int = 4096,
22
- n_embd: int = 1024,
23
- n_layer: int = 32,
24
- n_head: int = 64,
25
- max_position_embeddings: int = 4096,
26
- multi_query: bool = True,
27
- layer_norm_epsilon=1e-5,
28
- initializer_range=0.02,
29
- scale_attn_weights=True,
30
- use_cache=True,
31
- bos_token_id=-1,
32
- eos_token_id=0,
33
- attention_softmax_in_fp32=True,
34
- scale_attention_softmax_in_fp32=True,
35
- resid_pdrop=0.1,
36
- embd_pdrop=0.1,
37
- attn_pdrop=0.1,
38
- **kwargs,
39
  ):
40
  self.vocab_size = vocab_size
41
  self.n_positions = n_positions
@@ -43,19 +39,13 @@ class GPTRefactConfig(PretrainedConfig):
43
  self.n_layer = n_layer
44
  self.n_head = n_head
45
  self.n_inner = None
46
- self.resid_pdrop = resid_pdrop
47
- self.embd_pdrop = embd_pdrop
48
- self.attn_pdrop = attn_pdrop
49
  self.layer_norm_epsilon = layer_norm_epsilon
50
  self.initializer_range = initializer_range
51
- self.scale_attn_weights = scale_attn_weights
52
  self.use_cache = use_cache
53
  self.attention_softmax_in_fp32 = attention_softmax_in_fp32
54
  self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
55
-
56
- self.bos_token_id = bos_token_id
57
- self.eos_token_id = eos_token_id
58
-
59
  self.multi_query = multi_query
60
  self.max_position_embeddings = max_position_embeddings
61
- super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
 
1
  from transformers.configuration_utils import PretrainedConfig
2
  from transformers.utils import logging
3
 
 
4
  logger = logging.get_logger(__name__)
5
 
6
 
 
15
  }
16
 
17
  def __init__(
18
+ self,
19
+ vocab_size: int = 49216,
20
+ n_positions: int = 4096,
21
+ n_embd: int = 1024,
22
+ n_layer: int = 32,
23
+ n_head: int = 64,
24
+ max_position_embeddings: int = 4096,
25
+ multi_query: bool = True,
26
+ layer_norm_epsilon: float = 1e-5,
27
+ initializer_range: float = 0.02,
28
+ use_cache: bool = True,
29
+ eos_token_id: int = 0,
30
+ attention_softmax_in_fp32: bool = True,
31
+ scale_attention_softmax_in_fp32: bool = True,
32
+ attention_bias_in_fp32: bool = True,
33
+ torch_dtype: str = 'bfloat16',
34
+ **kwargs,
 
 
 
35
  ):
36
  self.vocab_size = vocab_size
37
  self.n_positions = n_positions
 
39
  self.n_layer = n_layer
40
  self.n_head = n_head
41
  self.n_inner = None
 
 
 
42
  self.layer_norm_epsilon = layer_norm_epsilon
43
  self.initializer_range = initializer_range
 
44
  self.use_cache = use_cache
45
  self.attention_softmax_in_fp32 = attention_softmax_in_fp32
46
  self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
47
+ self.attention_bias_in_fp32 = attention_bias_in_fp32
 
 
 
48
  self.multi_query = multi_query
49
  self.max_position_embeddings = max_position_embeddings
50
+ self.torch_dtype = torch_dtype
51
+ super().__init__(eos_token_id=eos_token_id, **kwargs)
generation_config.json CHANGED
@@ -1,6 +1,5 @@
1
  {
2
  "_from_model_config": true,
3
- "bos_token_id": -1,
4
  "do_sample": true,
5
  "eos_token_id": 0,
6
  "transformers_version": "4.31.0"
 
1
  {
2
  "_from_model_config": true,
 
3
  "do_sample": true,
4
  "eos_token_id": 0,
5
  "transformers_version": "4.31.0"
modeling_gpt_refact.py CHANGED
@@ -21,29 +21,23 @@ logger = logging.get_logger(__name__)
21
 
22
  @torch.jit.script
23
  def upcast_masked_softmax(
24
- x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
25
  ):
26
  input_dtype = x.dtype
27
- x = x.to(softmax_dtype) * scale
28
  x = torch.where(mask, x, mask_value)
29
  x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
30
  return x
31
 
32
 
33
  @torch.jit.script
34
- def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
35
  input_dtype = x.dtype
36
- x = x.to(softmax_dtype) * scale
37
  x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
38
  return x
39
 
40
 
41
- @torch.jit.script
42
- def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
43
- x = torch.where(mask, x, mask_value)
44
- x = torch.nn.functional.softmax(x, dim=-1)
45
- return x
46
-
47
  @torch.jit.script
48
  def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
49
  """
@@ -76,7 +70,6 @@ def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
76
  m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
77
  # Concatenate the slopes with the remaining slopes.
78
  m = torch.cat([m, m_hat])
79
-
80
  return m
81
 
82
  @torch.jit.script
@@ -85,8 +78,7 @@ def get_alibi_biases(
85
  T: int,
86
  attn_heads: int,
87
  dev: torch.device,
88
- dtype: torch.dtype,
89
- causal: bool = True) -> torch.Tensor:
90
  """
91
  ## Calculate the attention biases matrix
92
  * `n_heads` is the number of heads in the attention layer
@@ -95,28 +87,26 @@ def get_alibi_biases(
95
  """
96
 
97
  # Get slopes $m$ for each head
98
- if causal:
99
- mask = (torch.triu(torch.ones((T, T), device=dev)) == 1).transpose(0, 1)
100
- else:
101
- mask = torch.ones((T, T), device=dev, dtype=torch.bool)
102
 
103
- m = _get_slopes(attn_heads, dev)
104
 
105
  # Calculate distances $[0, 1, \dots, N]$
106
  # Here we calculate the distances using the mask.
107
  #
108
  # Since it's causal mask we can just use $[0, 1, \dots, N]$ too.
109
  # `distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]`
110
- distance = mask.cumsum(dim=-1)
111
 
112
  # Multiply them pair-wise to get the AliBi bias matrix
113
  biases = distance[:, :, None] * m[None, None, :]
114
  biases = biases.permute(2, 0, 1)[None, :, :T, :T]
115
  biases = biases.repeat(B, 1, 1, 1)
116
- return biases.to(dtype).contiguous()
117
 
118
 
119
  class Attention(nn.Module):
 
120
  def __init__(self, config, layer_idx=None):
121
  super().__init__()
122
  self.mask_value = None
@@ -126,7 +116,7 @@ class Attention(nn.Module):
126
  self.head_dim = self.embed_dim // self.num_heads
127
  self.kv_attn_heads = 1
128
 
129
- self.scale = self.head_dim ** -0.5
130
 
131
  if self.head_dim * self.num_heads != self.embed_dim:
132
  raise ValueError(
@@ -139,41 +129,64 @@ class Attention(nn.Module):
139
  self.scale_attention_softmax_in_fp32 = (
140
  config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
141
  )
 
142
 
143
  self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
144
  self.k = nn.Linear(self.embed_dim, self.head_dim, bias=False)
145
  self.v = nn.Linear(self.embed_dim, self.head_dim, bias=False)
146
  self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
147
 
 
 
 
 
 
 
148
  def _attn(self, query, key, value, attention_mask=None, alibi=None):
149
  dtype = query.dtype
150
  softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
 
151
  upcast = dtype != softmax_dtype
152
- unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
153
 
154
- attn_weights = (alibi + torch.matmul(query * self.scale, key)).to(query.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  if upcast:
 
 
157
  if attention_mask is None:
158
- attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
159
  else:
160
- mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
161
- attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
162
  else:
163
  if attention_mask is not None:
164
- attn_weights = torch.masked_fill(attn_weights, attention_mask, -10000)
165
-
166
  attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
167
 
168
- attn_output = torch.matmul(attn_weights, value)
169
 
170
  return attn_output, attn_weights
171
 
172
- def _split_heads(self, tensor):
173
- new_shape = tensor.shape[:-1] + (self.num_heads, self.head_dim)
174
- tensor = tensor.view(new_shape)
175
- return tensor.permute(0, 2, 1, 3)
176
-
177
  def forward(
178
  self,
179
  hidden_states: torch.Tensor,
@@ -186,13 +199,9 @@ class Attention(nn.Module):
186
  Tuple[torch.Tensor, Optional[torch.Tensor]],
187
  Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
188
  ]:
189
- b, t, _ = hidden_states.shape
190
  query = self.q(hidden_states)
191
  key = self.k(hidden_states)
192
  value = self.v(hidden_states)
193
- query = self._split_heads(query)
194
- key = key.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
195
- value = value.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
196
 
197
  if layer_past is not None:
198
  past_key, past_value = layer_past
@@ -205,18 +214,18 @@ class Attention(nn.Module):
205
  present = None
206
 
207
  attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
208
-
209
- attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
210
  attn_output = self.c_proj(attn_output)
211
 
212
  outputs = (attn_output, present)
213
  if output_attentions:
 
214
  outputs += (attn_weights,)
215
 
216
  return outputs # a, present, (attentions)
217
 
218
 
219
  class MLP(nn.Module):
 
220
  def __init__(self, intermediate_size, config, multiple_of: int = 256):
221
  super().__init__()
222
  embed_dim = config.hidden_size
@@ -227,7 +236,7 @@ class MLP(nn.Module):
227
  self.linear_3 = nn.Linear(embed_dim, hidden_dim, bias=False)
228
  self.c_proj = nn.Linear(hidden_dim, embed_dim, bias=False)
229
 
230
- def forward(self, x: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
231
  x1 = F.silu(self.linear_1(x))
232
  x2 = self.linear_3(x)
233
  x = self.c_proj(x1 * x2)
@@ -297,6 +306,7 @@ class GPTRefactBlock(nn.Module):
297
 
298
 
299
  class GPTRefactPreTrainedModel(PreTrainedModel):
 
300
  config_class = GPTRefactConfig
301
  base_model_prefix = "transformer"
302
  supports_gradient_checkpointing = True
@@ -337,6 +347,7 @@ class GPTRefactPreTrainedModel(PreTrainedModel):
337
 
338
 
339
  class GPTRefactModel(GPTRefactPreTrainedModel):
 
340
  def __init__(self, config):
341
  super().__init__(config)
342
  self.embed_dim = config.hidden_size
@@ -347,6 +358,7 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
347
  self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
348
 
349
  self.max_positions = config.max_position_embeddings
 
350
  self.register_buffer(
351
  "bias", torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)),
352
  persistent=False
@@ -357,16 +369,6 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
357
  # Initialize weights and apply final processing
358
  self.post_init()
359
 
360
- @staticmethod
361
- def _make_mask(seq_len: int, past_key_values_length: int):
362
- # prompt
363
- if past_key_values_length == 0:
364
- mask = torch.ones((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
365
- mask = torch.triu(mask, 1)
366
- else:
367
- mask = torch.zeros((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
368
- return mask
369
-
370
  def forward(
371
  self,
372
  input_ids: Optional[torch.Tensor] = None,
@@ -408,19 +410,25 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
408
  else:
409
  past_length = past_key_values[0][0].size(-2)
410
 
411
- # Self-attention mask.
412
  query_length = input_shape[-1]
413
-
414
  seq_length_with_past = past_length + query_length
415
- if attention_mask is None:
416
- attention_mask = self._make_mask(query_length, past_length).to(device)
417
- else:
418
- attention_mask = attention_mask.to(device)
 
 
 
 
 
 
 
419
 
420
  hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
421
 
 
422
  alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
423
- self.num_heads, device, torch.float32)[:, :, -query_length:, :]
424
 
425
  output_shape = input_shape + (hidden_states.size(-1),)
426
 
@@ -489,6 +497,7 @@ class GPTRefactModel(GPTRefactPreTrainedModel):
489
 
490
 
491
  class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
 
492
  _tied_weights_keys = ["lm_head.weight", "ln_f.weight"]
493
 
494
  def __init__(self, config):
 
21
 
22
  @torch.jit.script
23
  def upcast_masked_softmax(
24
+ x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, softmax_dtype: torch.dtype
25
  ):
26
  input_dtype = x.dtype
27
+ x = x.to(softmax_dtype)
28
  x = torch.where(mask, x, mask_value)
29
  x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
30
  return x
31
 
32
 
33
  @torch.jit.script
34
+ def upcast_softmax(x: torch.Tensor, softmax_dtype: torch.dtype):
35
  input_dtype = x.dtype
36
+ x = x.to(softmax_dtype)
37
  x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
38
  return x
39
 
40
 
 
 
 
 
 
 
41
  @torch.jit.script
42
  def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
43
  """
 
70
  m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
71
  # Concatenate the slopes with the remaining slopes.
72
  m = torch.cat([m, m_hat])
 
73
  return m
74
 
75
  @torch.jit.script
 
78
  T: int,
79
  attn_heads: int,
80
  dev: torch.device,
81
+ dtype: torch.dtype) -> torch.Tensor:
 
82
  """
83
  ## Calculate the attention biases matrix
84
  * `n_heads` is the number of heads in the attention layer
 
87
  """
88
 
89
  # Get slopes $m$ for each head
90
+ mask = torch.ones((T, T), device=dev, dtype=torch.bool)
 
 
 
91
 
92
+ m = _get_slopes(attn_heads, dev).to(dtype)
93
 
94
  # Calculate distances $[0, 1, \dots, N]$
95
  # Here we calculate the distances using the mask.
96
  #
97
  # Since it's causal mask we can just use $[0, 1, \dots, N]$ too.
98
  # `distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]`
99
+ distance = mask.cumsum(dim=-1).to(dtype)
100
 
101
  # Multiply them pair-wise to get the AliBi bias matrix
102
  biases = distance[:, :, None] * m[None, None, :]
103
  biases = biases.permute(2, 0, 1)[None, :, :T, :T]
104
  biases = biases.repeat(B, 1, 1, 1)
105
+ return biases.contiguous()
106
 
107
 
108
  class Attention(nn.Module):
109
+
110
  def __init__(self, config, layer_idx=None):
111
  super().__init__()
112
  self.mask_value = None
 
116
  self.head_dim = self.embed_dim // self.num_heads
117
  self.kv_attn_heads = 1
118
 
119
+ self.scale_factor = self.head_dim ** -0.5
120
 
121
  if self.head_dim * self.num_heads != self.embed_dim:
122
  raise ValueError(
 
129
  self.scale_attention_softmax_in_fp32 = (
130
  config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
131
  )
132
+ self.attention_bias_in_fp32 = config.attention_bias_in_fp32
133
 
134
  self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
135
  self.k = nn.Linear(self.embed_dim, self.head_dim, bias=False)
136
  self.v = nn.Linear(self.embed_dim, self.head_dim, bias=False)
137
  self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
138
 
139
+ def _get_mask_value(self, device, dtype):
140
+ # torch.where expects a tensor. We use a cache to avoid recreating it every time.
141
+ if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
142
+ self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
143
+ return self.mask_value
144
+
145
  def _attn(self, query, key, value, attention_mask=None, alibi=None):
146
  dtype = query.dtype
147
  softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
148
+ mask_value = self._get_mask_value(query.device, softmax_dtype)
149
  upcast = dtype != softmax_dtype
 
150
 
151
+ query_shape = query.shape
152
+ batch_size = query_shape[0]
153
+ key_length = key.size(-1)
154
+
155
+ # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
156
+ # -> (batch_size, query_length, num_heads, key_length)
157
+ query_length = query_shape[1]
158
+ attn_shape = (batch_size, query_length, self.num_heads, key_length)
159
+ attn_view = (batch_size, query_length * self.num_heads, key_length)
160
+ # No copy needed for MQA 2, or when layer_past is provided.
161
+ query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
162
+
163
+ alibi = alibi.transpose(2, 1).reshape(alibi.shape[0], -1, alibi.shape[-1])
164
+ initial_dtype = query.dtype
165
+ new_dtype = torch.float32 if self.attention_bias_in_fp32 else initial_dtype
166
+ attn_weights = alibi.baddbmm(
167
+ batch1=query.to(new_dtype),
168
+ batch2=key.to(new_dtype),
169
+ beta=1,
170
+ alpha=self.scale_factor
171
+ ).view(attn_shape).to(initial_dtype)
172
 
173
  if upcast:
174
+ # Use a fused kernel to prevent a large overhead from casting and scaling.
175
+ # Sub-optimal when the key length is not a multiple of 8.
176
  if attention_mask is None:
177
+ attn_weights = upcast_softmax(attn_weights, softmax_dtype)
178
  else:
179
+ attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, softmax_dtype)
 
180
  else:
181
  if attention_mask is not None:
182
+ # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
183
+ attn_weights = torch.where(attention_mask, attn_weights, mask_value)
184
  attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
185
 
186
+ attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
187
 
188
  return attn_output, attn_weights
189
 
 
 
 
 
 
190
  def forward(
191
  self,
192
  hidden_states: torch.Tensor,
 
199
  Tuple[torch.Tensor, Optional[torch.Tensor]],
200
  Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
201
  ]:
 
202
  query = self.q(hidden_states)
203
  key = self.k(hidden_states)
204
  value = self.v(hidden_states)
 
 
 
205
 
206
  if layer_past is not None:
207
  past_key, past_value = layer_past
 
214
  present = None
215
 
216
  attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
 
 
217
  attn_output = self.c_proj(attn_output)
218
 
219
  outputs = (attn_output, present)
220
  if output_attentions:
221
+ attn_weights = attn_weights.transpose(1, 2)
222
  outputs += (attn_weights,)
223
 
224
  return outputs # a, present, (attentions)
225
 
226
 
227
  class MLP(nn.Module):
228
+
229
  def __init__(self, intermediate_size, config, multiple_of: int = 256):
230
  super().__init__()
231
  embed_dim = config.hidden_size
 
236
  self.linear_3 = nn.Linear(embed_dim, hidden_dim, bias=False)
237
  self.c_proj = nn.Linear(hidden_dim, embed_dim, bias=False)
238
 
239
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
240
  x1 = F.silu(self.linear_1(x))
241
  x2 = self.linear_3(x)
242
  x = self.c_proj(x1 * x2)
 
306
 
307
 
308
  class GPTRefactPreTrainedModel(PreTrainedModel):
309
+
310
  config_class = GPTRefactConfig
311
  base_model_prefix = "transformer"
312
  supports_gradient_checkpointing = True
 
347
 
348
 
349
  class GPTRefactModel(GPTRefactPreTrainedModel):
350
+
351
  def __init__(self, config):
352
  super().__init__(config)
353
  self.embed_dim = config.hidden_size
 
358
  self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
359
 
360
  self.max_positions = config.max_position_embeddings
361
+ self.attention_bias_in_fp32 = config.attention_bias_in_fp32
362
  self.register_buffer(
363
  "bias", torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)),
364
  persistent=False
 
369
  # Initialize weights and apply final processing
370
  self.post_init()
371
 
 
 
 
 
 
 
 
 
 
 
372
  def forward(
373
  self,
374
  input_ids: Optional[torch.Tensor] = None,
 
410
  else:
411
  past_length = past_key_values[0][0].size(-2)
412
 
 
413
  query_length = input_shape[-1]
 
414
  seq_length_with_past = past_length + query_length
415
+
416
+ # Self-attention mask.
417
+ key_length = past_length + query_length
418
+ self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
419
+ if attention_mask is not None:
420
+ self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
421
+ dtype=torch.bool, device=self_attention_mask.device
422
+ )
423
+
424
+ # MQA models: (batch_size, query_length, n_heads, key_length)
425
+ attention_mask = self_attention_mask.unsqueeze(2)
426
 
427
  hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
428
 
429
+ alibi_dtype = torch.float32 if self.attention_bias_in_fp32 else self.wte.weight.dtype
430
  alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
431
+ self.num_heads, device, alibi_dtype)[:, :, -query_length:, :]
432
 
433
  output_shape = input_shape + (hidden_states.size(-1),)
434
 
 
497
 
498
 
499
  class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
500
+
501
  _tied_weights_keys = ["lm_head.weight", "ln_f.weight"]
502
 
503
  def __init__(self, config):
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8c9761aabc16466fdf738d4fe42f12ee6844a360db07bde307ca808d0bfb6b8a
3
- size 6343461637
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1092c5efe56fe5b04360ba0d4ac231e8b03f9d1d0b8633b8ed678f73bdcb021a
3
+ size 3171776281