OpenNLPLab commited on
Commit
bb9fa21
·
verified ·
1 Parent(s): 98b2957

Upload modeling_transnormer.py

Browse files
Files changed (1) hide show
  1. modeling_transnormer.py +155 -163
modeling_transnormer.py CHANGED
@@ -53,8 +53,13 @@ logger = logging.get_logger(__name__)
53
 
54
  _CONFIG_FOR_DOC = "TransnormerConfig"
55
 
 
56
  use_triton = eval(os.environ.get("use_triton", default="True"))
57
  debug = eval(os.environ.get("debug", default="False"))
 
 
 
 
58
 
59
  if use_triton:
60
  try:
@@ -80,9 +85,11 @@ if not has_lightning_attention:
80
 
81
  return output
82
 
 
83
  ########## start Transnormer
84
  ##### Linearized Relative Positional Encoding: https://openreview.net/forum?id=xoLyps2qWc&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DTMLR%2FAuthors%23your-submissions)
85
  class Lrpe(nn.Module):
 
86
  def __init__(
87
  self,
88
  num_heads=8,
@@ -92,9 +99,8 @@ class Lrpe(nn.Module):
92
  d = num_heads * embed_dim
93
 
94
  self.index = torch.empty(0)
95
- self.theta = nn.Parameter(
96
- 10000 ** (-2 / d * torch.arange(d)).reshape(num_heads, 1, -1)
97
- )
98
 
99
  def extra_repr(self):
100
  return print_module(self)
@@ -113,6 +119,7 @@ class Lrpe(nn.Module):
113
 
114
 
115
  class GLU(nn.Module):
 
116
  def __init__(self, d1, d2, bias=False):
117
  super().__init__()
118
  if debug:
@@ -135,6 +142,7 @@ class GLU(nn.Module):
135
 
136
 
137
  class NormLinearAttention(nn.Module):
 
138
  def __init__(
139
  self,
140
  embed_dim,
@@ -181,7 +189,6 @@ class NormLinearAttention(nn.Module):
181
  use_cache: bool = False,
182
  slope_rate: Optional[torch.Tensor] = None,
183
  ):
184
- do_eval = eval(os.environ.get("do_eval", default="False"))
185
  if (not self.training) and (not do_eval):
186
  return self.inference(
187
  x,
@@ -198,8 +205,8 @@ class NormLinearAttention(nn.Module):
198
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
199
  # reshape
200
  q, k, v = map(
201
- lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads), [q, k, v]
202
- )
203
  # act
204
  q = self.act(q)
205
  k = self.act(k)
@@ -217,24 +224,23 @@ class NormLinearAttention(nn.Module):
217
  # lrpe
218
  if self.linear_use_lrpe:
219
  q = self.lrpe(q, offset=q_offset)
220
- k = self.lrpe(k)
221
 
222
  if attn_mask == None:
223
  attn_mask = (torch.tril(torch.ones(n, n))).to(q)
224
 
225
  if attn_padding_mask is not None:
226
  v = v.masked_fill(
227
- (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
228
- )
229
 
230
  if not has_lightning_attention:
231
  if slope_rate != None:
232
  attn_mask = torch.exp(slope_rate * attn_mask)
233
  output = linear_attention(q, k, v, attn_mask)
234
  else:
235
- output = lightning_attention(
236
- q, k, v, True, slope_rate.squeeze(-1).squeeze(-1)
237
- )
238
 
239
  # reshape
240
  output = rearrange(output, "b h n d -> b n (h d)")
@@ -253,14 +259,14 @@ class NormLinearAttention(nn.Module):
253
  return output, attn_weights, past_key_value
254
 
255
  def inference(
256
- self,
257
- x,
258
- attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
259
- attn_padding_mask: Optional[torch.Tensor] = None, # (b, m)
260
- output_attentions: bool = False,
261
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
262
- use_cache: bool = False,
263
- slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
264
  ):
265
  # x: b n d
266
  n = x.shape[-2]
@@ -268,8 +274,8 @@ class NormLinearAttention(nn.Module):
268
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
269
  # reshape
270
  q, k, v = map(
271
- lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads), [q, k, v]
272
- )
273
  # act
274
  q = self.act(q)
275
  k = self.act(k)
@@ -277,7 +283,7 @@ class NormLinearAttention(nn.Module):
277
  # rpe
278
  if self.linear_use_lrpe:
279
  q = self.lrpe(q, offset=self.offset)
280
- k = self.lrpe(k)
281
 
282
  if past_key_value == None:
283
  self.offset = q.shape[-2]
@@ -288,38 +294,47 @@ class NormLinearAttention(nn.Module):
288
 
289
  # only use for the first time
290
  if past_key_value == None:
291
- if attn_mask == None:
292
- attn_mask = (torch.tril(torch.ones(n, n))).to(q)
293
- if slope_rate != None:
294
- attn_mask = torch.exp(slope_rate * attn_mask)
295
-
296
  if attn_padding_mask is not None:
297
- attn_mask = attn_mask.masked_fill(
298
- (1 - attn_padding_mask).unsqueeze(1).unsqueeze(2).to(torch.bool),
299
- 0,
300
- )
301
- energy = torch.einsum("... n d, ... m d -> ... n m", q, k)
302
-
303
- if attn_mask != None:
304
- energy = energy * attn_mask
305
-
306
- output = torch.einsum("... n m, ... m d -> ... n d", energy, v)
307
-
308
- eval_and_not_generate = eval(
309
- os.environ.get("eval_and_not_generate", default="False")
310
- )
311
- if eval_and_not_generate:
312
- kv = None
313
- else:
314
- # b, h, n, e, d
315
- kv_outproduct = torch.einsum("... n e, ... n d -> ... n e d", k, v)
316
- # 1, 1, n, 1, 1
317
- index = torch.arange(n - 1, -1, -1).reshape(1, 1, -1, 1, 1).to(x)
318
- # (h, 1, 1) -> (1, h, 1, 1, 1); (1, h, 1, 1, 1), (1, 1, n, 1, 1) -> (1, h, n, 1, 1)
319
- decay = ratio.unsqueeze(0).unsqueeze(-1) ** index
320
-
321
- kv_outproduct_with_decay = kv_outproduct * decay
322
- kv = torch.sum(kv_outproduct_with_decay, dim=-3)
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  else:
324
  kv = past_key_value
325
 
@@ -327,12 +342,11 @@ class NormLinearAttention(nn.Module):
327
  for i in range(n):
328
  kv = ratio * kv + torch.einsum(
329
  "... n d, ... n e -> ... d e",
330
- k[:, :, i : i + 1],
331
- v[:, :, i : i + 1],
332
- )
333
- qkv = torch.einsum(
334
- "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv
335
  )
 
 
336
  output.append(qkv)
337
  output = torch.concat(output, dim=-2)
338
 
@@ -351,6 +365,7 @@ class NormLinearAttention(nn.Module):
351
 
352
 
353
  class TransnormerDecoderLayer(nn.Module):
 
354
  def __init__(self, config: TransnormerConfig):
355
  super().__init__()
356
  self.embed_dim = config.decoder_embed_dim
@@ -389,18 +404,18 @@ class TransnormerDecoderLayer(nn.Module):
389
  return residual + x
390
 
391
  def forward(
392
- self,
393
- x,
394
- attn_mask: Optional[torch.Tensor] = None,
395
- attn_padding_mask: Optional[torch.Tensor] = None,
396
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
397
- output_attentions: Optional[bool] = False,
398
- use_cache: Optional[bool] = False,
399
- slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
400
  ):
401
  residual = x
402
  input = x
403
-
404
  o1, self_attn_weights, present_key_value = self.token_mixer(
405
  x=self.token_norm(input),
406
  attn_mask=attn_mask,
@@ -418,10 +433,10 @@ class TransnormerDecoderLayer(nn.Module):
418
  outputs = (o, )
419
 
420
  if output_attentions:
421
- outputs += (self_attn_weights,)
422
 
423
  if use_cache:
424
- outputs += (present_key_value,)
425
 
426
  return outputs
427
 
@@ -443,9 +458,7 @@ TRANSNORMER_START_DOCSTRING = r"""
443
  """
444
 
445
 
446
- @add_start_docstrings(
447
- TRANSNORMER_START_DOCSTRING,
448
- )
449
  class TransnormerPreTrainedModel(PreTrainedModel):
450
  config_class = TransnormerConfig
451
  base_model_prefix = "model"
@@ -530,9 +543,7 @@ TRANSNORMER_INPUTS_DOCSTRING = r"""
530
  """
531
 
532
 
533
- @add_start_docstrings(
534
- TRANSNORMER_START_DOCSTRING,
535
- )
536
  class TransnormerModel(TransnormerPreTrainedModel):
537
  """
538
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TransnormerDecoderLayer`]
@@ -556,29 +567,31 @@ class TransnormerModel(TransnormerPreTrainedModel):
556
  self.slopes = self._build_slope_tensor(config.decoder_attention_heads)
557
 
558
  # params
559
- self.embed_tokens = nn.Embedding(
560
- config.vocab_size, config.decoder_embed_dim, self.padding_idx
561
- )
562
  self.layers = nn.ModuleList([])
563
  for i in range(config.decoder_layers):
564
  if len(self.linear_use_lrpe_list) > 0:
565
  config.linear_use_lrpe = self.linear_use_lrpe_list[i]
566
  self.layers.append(TransnormerDecoderLayer(config))
567
 
568
- self.final_norm = get_norm_fn(config.norm_type)(config.decoder_embed_dim)
 
569
  self.embed_dim = config.decoder_embed_dim
570
- self.embed_scale = (
571
- 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
572
- )
573
 
574
  # Initialize weights and apply final processing
575
  self.post_init()
576
 
577
  @staticmethod
578
  def _build_slope_tensor(n_attention_heads: int):
 
579
  def get_slopes(n):
 
580
  def get_slopes_power_of_2(n):
581
- start = 2 ** (-(2 ** -(math.log2(n) - 3)))
582
  ratio = start
583
  return [start * ratio**i for i in range(n)]
584
 
@@ -587,18 +600,15 @@ class TransnormerModel(TransnormerPreTrainedModel):
587
  n
588
  ) # In the paper, we only train models that have 2^a heads for some a. This function has
589
  else: # some good properties that only occur when the input is a power of 2. To maintain that even
590
- closest_power_of_2 = 2 ** math.floor(
591
  math.log2(n)
592
  ) # when the number of heads is not a power of 2, we use this workaround.
593
- return (
594
- get_slopes_power_of_2(closest_power_of_2)
595
- + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
596
- )
597
 
598
  # h, 1, 1
599
  slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
600
- n_attention_heads, 1, 1
601
- )
602
 
603
  return slopes
604
 
@@ -611,26 +621,26 @@ class TransnormerModel(TransnormerPreTrainedModel):
611
  def set_input_embeddings(self, value):
612
  self.embed_tokens = value
613
 
614
- def _prepare_decoder_linear_attn_mask(
615
- self, input_shape, inputs_embeds, past_key_values_length
616
- ):
617
  bsz, tgt_len = input_shape
618
  src_len = tgt_len + past_key_values_length
619
 
620
  def power_log(x):
621
- return 2 ** (math.ceil(math.log(x, 2)))
622
 
623
  n = power_log(max(tgt_len, src_len))
624
  if self._linear_attn_mask.shape[-1] < n:
625
 
626
  def get_mask(n):
627
- mask = torch.triu(torch.zeros(n, n).float().fill_(float("-inf")), 1)
 
628
  # no slope version
629
  # -n, ..., -2, -1, 0
630
  for i in range(n):
631
  x = torch.arange(i + 1)
632
  y = x
633
- mask[i, : i + 1] = -torch.flip(y, [0])
634
 
635
  return mask
636
 
@@ -642,7 +652,8 @@ class TransnormerModel(TransnormerPreTrainedModel):
642
  linear_attn_mask = self._linear_attn_mask[:, -tgt_len:, -src_len:]
643
  num_heads = linear_attn_mask.shape[0]
644
 
645
- return linear_attn_mask[None, :, :, :].expand(bsz, num_heads, tgt_len, src_len)
 
646
 
647
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
648
  def forward(
@@ -656,21 +667,15 @@ class TransnormerModel(TransnormerPreTrainedModel):
656
  output_hidden_states: Optional[bool] = None,
657
  return_dict: Optional[bool] = None,
658
  ) -> Union[Tuple, BaseModelOutputWithPast]:
659
- output_attentions = (
660
- output_attentions
661
- if output_attentions is not None
662
- else self.config.output_attentions
663
- )
664
- output_hidden_states = (
665
- output_hidden_states
666
- if output_hidden_states is not None
667
- else self.config.output_hidden_states
668
- )
669
  use_cache = use_cache if use_cache is not None else self.config.use_cache
670
 
671
- return_dict = (
672
- return_dict if return_dict is not None else self.config.use_return_dict
673
- )
674
 
675
  # retrieve input_ids and inputs_embeds
676
  if input_ids is not None and inputs_embeds is not None:
@@ -692,7 +697,7 @@ class TransnormerModel(TransnormerPreTrainedModel):
692
  if past_key_values is not None:
693
  past_key_values_length = past_key_values[0][0].shape[-2]
694
  seq_length_with_past = seq_length_with_past + past_key_values_length
695
-
696
  if inputs_embeds is None:
697
  # !!! use embed_scale
698
  inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
@@ -714,23 +719,23 @@ class TransnormerModel(TransnormerPreTrainedModel):
714
  ##### norm linear layers
715
  linear_attn_padding_mask = attn_padding_mask
716
  linear_attn_mask = self._prepare_decoder_linear_attn_mask(
717
- (batch_size, seq_length), inputs_embeds, past_key_values_length
718
- )
719
 
720
- slope_rates = [self.slopes.to(input_ids.device) for _ in range(self.num_layers)]
 
 
721
 
722
  for idx, layer in enumerate(self.layers):
723
  if output_hidden_states:
724
- all_hidden_states += (hidden_states,)
725
 
726
- past_key_value = (
727
- past_key_values[idx] if past_key_values is not None else None
728
- )
729
 
730
  slope_rate = slope_rates[idx]
731
  slope_rate = slope_rate * (1 - idx / (self.num_layers - 1) + 1e-5)
732
  mask = linear_attn_mask
733
-
734
  layer_outputs = layer(
735
  hidden_states,
736
  attn_mask=mask,
@@ -744,27 +749,24 @@ class TransnormerModel(TransnormerPreTrainedModel):
744
  hidden_states = layer_outputs[0]
745
 
746
  if use_cache:
747
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
 
748
 
749
  if output_attentions:
750
- all_self_attns += (layer_outputs[1],)
751
-
752
- # if idx == 0:
753
- # break
754
 
755
  hidden_states = self.final_norm(hidden_states)
756
 
757
  # add hidden states from the last decoder layer
758
  if output_hidden_states:
759
- all_hidden_states += (hidden_states,)
760
 
761
  next_cache = next_decoder_cache if use_cache else None
762
  if not return_dict:
763
  return tuple(
764
- v
765
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
766
- if v is not None
767
- )
768
  return BaseModelOutputWithPast(
769
  last_hidden_state=hidden_states,
770
  past_key_values=next_cache,
@@ -774,6 +776,7 @@ class TransnormerModel(TransnormerPreTrainedModel):
774
 
775
 
776
  class TransnormerForCausalLM(TransnormerPreTrainedModel):
 
777
  def __init__(self, config):
778
  super().__init__(config)
779
  self.model = TransnormerModel(config)
@@ -781,9 +784,9 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
781
  logging_info(self.model)
782
 
783
  # the lm_head weight is automatically tied to the embed tokens weight
784
- self.lm_head = nn.Linear(
785
- config.decoder_embed_dim, config.vocab_size, bias=False
786
- )
787
 
788
  # Initialize weights and apply final processing
789
  self.post_init()
@@ -807,9 +810,8 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
807
  return self.model
808
 
809
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
810
- @replace_return_docstrings(
811
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
812
- )
813
  def forward(
814
  self,
815
  input_ids: torch.LongTensor = None,
@@ -847,19 +849,13 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
847
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
848
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
849
  ```"""
850
- output_attentions = (
851
- output_attentions
852
- if output_attentions is not None
853
- else self.config.output_attentions
854
- )
855
- output_hidden_states = (
856
- output_hidden_states
857
- if output_hidden_states is not None
858
- else self.config.output_hidden_states
859
- )
860
- return_dict = (
861
- return_dict if return_dict is not None else self.config.use_return_dict
862
- )
863
 
864
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
865
  outputs = self.model(
@@ -890,8 +886,8 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
890
  loss = loss_fct(shift_logits, shift_labels)
891
 
892
  if not return_dict:
893
- output = (logits,) + outputs[1:]
894
- return (loss,) + output if loss is not None else output
895
 
896
  return CausalLMOutputWithPast(
897
  loss=loss,
@@ -918,22 +914,18 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
918
  else:
919
  model_inputs = {"input_ids": input_ids}
920
 
921
- model_inputs.update(
922
- {
923
- "past_key_values": past_key_values,
924
- "use_cache": kwargs.get("use_cache"),
925
- "attention_mask": attention_mask,
926
- }
927
- )
928
  return model_inputs
929
 
930
  @staticmethod
931
  def _reorder_cache(past_key_values, beam_idx):
932
  reordered_past = ()
933
  for layer_past in past_key_values:
934
- reordered_past += (
935
- tuple(
936
- past_state.index_select(0, beam_idx) for past_state in layer_past
937
- ),
938
- )
939
  return reordered_past
 
53
 
54
  _CONFIG_FOR_DOC = "TransnormerConfig"
55
 
56
+ # TODO: fix environment: https://huggingface.co/OpenNLPLab/TransNormerLLM-7B/discussions/1
57
  use_triton = eval(os.environ.get("use_triton", default="True"))
58
  debug = eval(os.environ.get("debug", default="False"))
59
+ do_eval = eval(os.environ.get("do_eval", default="False"))
60
+ eval_and_not_generate = eval(
61
+ os.environ.get("eval_and_not_generate", default="False"))
62
+ BLOCK = 256
63
 
64
  if use_triton:
65
  try:
 
85
 
86
  return output
87
 
88
+
89
  ########## start Transnormer
90
  ##### Linearized Relative Positional Encoding: https://openreview.net/forum?id=xoLyps2qWc&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DTMLR%2FAuthors%23your-submissions)
91
  class Lrpe(nn.Module):
92
+
93
  def __init__(
94
  self,
95
  num_heads=8,
 
99
  d = num_heads * embed_dim
100
 
101
  self.index = torch.empty(0)
102
+ self.theta = nn.Parameter(10000**(-2 / d * torch.arange(d)).reshape(
103
+ num_heads, 1, -1))
 
104
 
105
  def extra_repr(self):
106
  return print_module(self)
 
119
 
120
 
121
  class GLU(nn.Module):
122
+
123
  def __init__(self, d1, d2, bias=False):
124
  super().__init__()
125
  if debug:
 
142
 
143
 
144
  class NormLinearAttention(nn.Module):
145
+
146
  def __init__(
147
  self,
148
  embed_dim,
 
189
  use_cache: bool = False,
190
  slope_rate: Optional[torch.Tensor] = None,
191
  ):
 
192
  if (not self.training) and (not do_eval):
193
  return self.inference(
194
  x,
 
205
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
206
  # reshape
207
  q, k, v = map(
208
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads),
209
+ [q, k, v])
210
  # act
211
  q = self.act(q)
212
  k = self.act(k)
 
224
  # lrpe
225
  if self.linear_use_lrpe:
226
  q = self.lrpe(q, offset=q_offset)
227
+ k = self.lrpe(k, offset=q_offset)
228
 
229
  if attn_mask == None:
230
  attn_mask = (torch.tril(torch.ones(n, n))).to(q)
231
 
232
  if attn_padding_mask is not None:
233
  v = v.masked_fill(
234
+ (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(
235
+ torch.bool), 0)
236
 
237
  if not has_lightning_attention:
238
  if slope_rate != None:
239
  attn_mask = torch.exp(slope_rate * attn_mask)
240
  output = linear_attention(q, k, v, attn_mask)
241
  else:
242
+ output = lightning_attention(q, k, v, True,
243
+ slope_rate.squeeze(-1).squeeze(-1))
 
244
 
245
  # reshape
246
  output = rearrange(output, "b h n d -> b n (h d)")
 
259
  return output, attn_weights, past_key_value
260
 
261
  def inference(
262
+ self,
263
+ x,
264
+ attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
265
+ attn_padding_mask: Optional[torch.Tensor] = None, # (b, m)
266
+ output_attentions: bool = False,
267
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
268
+ use_cache: bool = False,
269
+ slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
270
  ):
271
  # x: b n d
272
  n = x.shape[-2]
 
274
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
275
  # reshape
276
  q, k, v = map(
277
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads),
278
+ [q, k, v])
279
  # act
280
  q = self.act(q)
281
  k = self.act(k)
 
283
  # rpe
284
  if self.linear_use_lrpe:
285
  q = self.lrpe(q, offset=self.offset)
286
+ k = self.lrpe(k, offset=self.offset)
287
 
288
  if past_key_value == None:
289
  self.offset = q.shape[-2]
 
294
 
295
  # only use for the first time
296
  if past_key_value == None:
297
+ slope_rate = slope_rate.to(torch.float32)
 
 
 
 
298
  if attn_padding_mask is not None:
299
+ v = v.masked_fill(
300
+ (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(
301
+ torch.bool), 0)
302
+ NUM_BLOCK = (n + BLOCK - 1) // BLOCK
303
+ b, h, n, d = q.shape
304
+ e = v.shape[-1]
305
+ # other
306
+ array = torch.arange(BLOCK).to(q) + 1 ## !!!! important
307
+ q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
308
+ k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
309
+ index = array[:, None] - array[None, :]
310
+ s_index = slope_rate * index[
311
+ None,
312
+ None,
313
+ ]
314
+ s_index = torch.where(index >= 0, -s_index, float("-inf"))
315
+ diag_decay = torch.exp(s_index)
316
+
317
+ kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
318
+ output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
319
+ for i in range(NUM_BLOCK):
320
+ si = i * BLOCK
321
+ ei = min(si + BLOCK, n)
322
+ m = ei - si
323
+
324
+ qi = q[:, :, si:ei].contiguous()
325
+ ki = k[:, :, si:ei].contiguous()
326
+ vi = v[:, :, si:ei].contiguous()
327
+ qkv_none_diag = torch.matmul(qi * q_decay[:, :m],
328
+ kv).to(torch.float32)
329
+
330
+ # diag
331
+ qk = torch.matmul(qi, ki.transpose(-1, -2)).to(
332
+ torch.float32) * diag_decay[:, :, :m, :m]
333
+ qkv_diag = torch.matmul(qk, vi.to(torch.float32))
334
+ block_decay = torch.exp(-slope_rate * m)
335
+ output[:, :, si:ei] = qkv_none_diag + qkv_diag
336
+ kv = block_decay * kv + torch.matmul(
337
+ (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi)
338
  else:
339
  kv = past_key_value
340
 
 
342
  for i in range(n):
343
  kv = ratio * kv + torch.einsum(
344
  "... n d, ... n e -> ... d e",
345
+ k[:, :, i:i + 1],
346
+ v[:, :, i:i + 1],
 
 
 
347
  )
348
+ qkv = torch.einsum("... n e, ... e d -> ... n d",
349
+ q[:, :, i:i + 1], kv)
350
  output.append(qkv)
351
  output = torch.concat(output, dim=-2)
352
 
 
365
 
366
 
367
  class TransnormerDecoderLayer(nn.Module):
368
+
369
  def __init__(self, config: TransnormerConfig):
370
  super().__init__()
371
  self.embed_dim = config.decoder_embed_dim
 
404
  return residual + x
405
 
406
  def forward(
407
+ self,
408
+ x,
409
+ attn_mask: Optional[torch.Tensor] = None,
410
+ attn_padding_mask: Optional[torch.Tensor] = None,
411
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
412
+ output_attentions: Optional[bool] = False,
413
+ use_cache: Optional[bool] = False,
414
+ slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
415
  ):
416
  residual = x
417
  input = x
418
+
419
  o1, self_attn_weights, present_key_value = self.token_mixer(
420
  x=self.token_norm(input),
421
  attn_mask=attn_mask,
 
433
  outputs = (o, )
434
 
435
  if output_attentions:
436
+ outputs += (self_attn_weights, )
437
 
438
  if use_cache:
439
+ outputs += (present_key_value, )
440
 
441
  return outputs
442
 
 
458
  """
459
 
460
 
461
+ @add_start_docstrings(TRANSNORMER_START_DOCSTRING, )
 
 
462
  class TransnormerPreTrainedModel(PreTrainedModel):
463
  config_class = TransnormerConfig
464
  base_model_prefix = "model"
 
543
  """
544
 
545
 
546
+ @add_start_docstrings(TRANSNORMER_START_DOCSTRING, )
 
 
547
  class TransnormerModel(TransnormerPreTrainedModel):
548
  """
549
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TransnormerDecoderLayer`]
 
567
  self.slopes = self._build_slope_tensor(config.decoder_attention_heads)
568
 
569
  # params
570
+ self.embed_tokens = nn.Embedding(config.vocab_size,
571
+ config.decoder_embed_dim,
572
+ self.padding_idx)
573
  self.layers = nn.ModuleList([])
574
  for i in range(config.decoder_layers):
575
  if len(self.linear_use_lrpe_list) > 0:
576
  config.linear_use_lrpe = self.linear_use_lrpe_list[i]
577
  self.layers.append(TransnormerDecoderLayer(config))
578
 
579
+ self.final_norm = get_norm_fn(config.norm_type)(
580
+ config.decoder_embed_dim)
581
  self.embed_dim = config.decoder_embed_dim
582
+ self.embed_scale = (1.0 if config.no_scale_embedding else math.sqrt(
583
+ self.embed_dim))
 
584
 
585
  # Initialize weights and apply final processing
586
  self.post_init()
587
 
588
  @staticmethod
589
  def _build_slope_tensor(n_attention_heads: int):
590
+
591
  def get_slopes(n):
592
+
593
  def get_slopes_power_of_2(n):
594
+ start = 2**(-(2**-(math.log2(n) - 3)))
595
  ratio = start
596
  return [start * ratio**i for i in range(n)]
597
 
 
600
  n
601
  ) # In the paper, we only train models that have 2^a heads for some a. This function has
602
  else: # some good properties that only occur when the input is a power of 2. To maintain that even
603
+ closest_power_of_2 = 2**math.floor(
604
  math.log2(n)
605
  ) # when the number of heads is not a power of 2, we use this workaround.
606
+ return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
607
+ 2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
 
 
608
 
609
  # h, 1, 1
610
  slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
611
+ n_attention_heads, 1, 1)
 
612
 
613
  return slopes
614
 
 
621
  def set_input_embeddings(self, value):
622
  self.embed_tokens = value
623
 
624
+ def _prepare_decoder_linear_attn_mask(self, input_shape, inputs_embeds,
625
+ past_key_values_length):
 
626
  bsz, tgt_len = input_shape
627
  src_len = tgt_len + past_key_values_length
628
 
629
  def power_log(x):
630
+ return 2**(math.ceil(math.log(x, 2)))
631
 
632
  n = power_log(max(tgt_len, src_len))
633
  if self._linear_attn_mask.shape[-1] < n:
634
 
635
  def get_mask(n):
636
+ mask = torch.triu(
637
+ torch.zeros(n, n).float().fill_(float("-inf")), 1)
638
  # no slope version
639
  # -n, ..., -2, -1, 0
640
  for i in range(n):
641
  x = torch.arange(i + 1)
642
  y = x
643
+ mask[i, :i + 1] = -torch.flip(y, [0])
644
 
645
  return mask
646
 
 
652
  linear_attn_mask = self._linear_attn_mask[:, -tgt_len:, -src_len:]
653
  num_heads = linear_attn_mask.shape[0]
654
 
655
+ return linear_attn_mask[None, :, :, :].expand(bsz, num_heads, tgt_len,
656
+ src_len)
657
 
658
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
659
  def forward(
 
667
  output_hidden_states: Optional[bool] = None,
668
  return_dict: Optional[bool] = None,
669
  ) -> Union[Tuple, BaseModelOutputWithPast]:
670
+ output_attentions = (output_attentions if output_attentions is not None
671
+ else self.config.output_attentions)
672
+ output_hidden_states = (output_hidden_states
673
+ if output_hidden_states is not None else
674
+ self.config.output_hidden_states)
 
 
 
 
 
675
  use_cache = use_cache if use_cache is not None else self.config.use_cache
676
 
677
+ return_dict = (return_dict if return_dict is not None else
678
+ self.config.use_return_dict)
 
679
 
680
  # retrieve input_ids and inputs_embeds
681
  if input_ids is not None and inputs_embeds is not None:
 
697
  if past_key_values is not None:
698
  past_key_values_length = past_key_values[0][0].shape[-2]
699
  seq_length_with_past = seq_length_with_past + past_key_values_length
700
+
701
  if inputs_embeds is None:
702
  # !!! use embed_scale
703
  inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
 
719
  ##### norm linear layers
720
  linear_attn_padding_mask = attn_padding_mask
721
  linear_attn_mask = self._prepare_decoder_linear_attn_mask(
722
+ (batch_size, seq_length), inputs_embeds, past_key_values_length)
 
723
 
724
+ slope_rates = [
725
+ self.slopes.to(input_ids.device) for _ in range(self.num_layers)
726
+ ]
727
 
728
  for idx, layer in enumerate(self.layers):
729
  if output_hidden_states:
730
+ all_hidden_states += (hidden_states, )
731
 
732
+ past_key_value = (past_key_values[idx]
733
+ if past_key_values is not None else None)
 
734
 
735
  slope_rate = slope_rates[idx]
736
  slope_rate = slope_rate * (1 - idx / (self.num_layers - 1) + 1e-5)
737
  mask = linear_attn_mask
738
+
739
  layer_outputs = layer(
740
  hidden_states,
741
  attn_mask=mask,
 
749
  hidden_states = layer_outputs[0]
750
 
751
  if use_cache:
752
+ next_decoder_cache += (
753
+ layer_outputs[2 if output_attentions else 1], )
754
 
755
  if output_attentions:
756
+ all_self_attns += (layer_outputs[1], )
 
 
 
757
 
758
  hidden_states = self.final_norm(hidden_states)
759
 
760
  # add hidden states from the last decoder layer
761
  if output_hidden_states:
762
+ all_hidden_states += (hidden_states, )
763
 
764
  next_cache = next_decoder_cache if use_cache else None
765
  if not return_dict:
766
  return tuple(
767
+ v for v in
768
+ [hidden_states, next_cache, all_hidden_states, all_self_attns]
769
+ if v is not None)
 
770
  return BaseModelOutputWithPast(
771
  last_hidden_state=hidden_states,
772
  past_key_values=next_cache,
 
776
 
777
 
778
  class TransnormerForCausalLM(TransnormerPreTrainedModel):
779
+
780
  def __init__(self, config):
781
  super().__init__(config)
782
  self.model = TransnormerModel(config)
 
784
  logging_info(self.model)
785
 
786
  # the lm_head weight is automatically tied to the embed tokens weight
787
+ self.lm_head = nn.Linear(config.decoder_embed_dim,
788
+ config.vocab_size,
789
+ bias=False)
790
 
791
  # Initialize weights and apply final processing
792
  self.post_init()
 
810
  return self.model
811
 
812
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
813
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast,
814
+ config_class=_CONFIG_FOR_DOC)
 
815
  def forward(
816
  self,
817
  input_ids: torch.LongTensor = None,
 
849
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
850
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
851
  ```"""
852
+ output_attentions = (output_attentions if output_attentions is not None
853
+ else self.config.output_attentions)
854
+ output_hidden_states = (output_hidden_states
855
+ if output_hidden_states is not None else
856
+ self.config.output_hidden_states)
857
+ return_dict = (return_dict if return_dict is not None else
858
+ self.config.use_return_dict)
 
 
 
 
 
 
859
 
860
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
861
  outputs = self.model(
 
886
  loss = loss_fct(shift_logits, shift_labels)
887
 
888
  if not return_dict:
889
+ output = (logits, ) + outputs[1:]
890
+ return (loss, ) + output if loss is not None else output
891
 
892
  return CausalLMOutputWithPast(
893
  loss=loss,
 
914
  else:
915
  model_inputs = {"input_ids": input_ids}
916
 
917
+ model_inputs.update({
918
+ "past_key_values": past_key_values,
919
+ "use_cache": kwargs.get("use_cache"),
920
+ "attention_mask": attention_mask,
921
+ })
 
 
922
  return model_inputs
923
 
924
  @staticmethod
925
  def _reorder_cache(past_key_values, beam_idx):
926
  reordered_past = ()
927
  for layer_past in past_key_values:
928
+ reordered_past += (tuple(
929
+ past_state.index_select(0, beam_idx)
930
+ for past_state in layer_past), )
 
 
931
  return reordered_past