bys0318 commited on
Commit
6eff855
·
verified ·
1 Parent(s): 81b025e

Modify to original GLM-4-9B code

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +298 -132
modeling_chatglm.py CHANGED
@@ -1,42 +1,39 @@
1
  """ PyTorch ChatGLM model. """
2
 
3
  import math
4
- import copy
5
- import warnings
6
- import re
7
  import sys
8
-
9
  import torch
10
  import torch.utils.checkpoint
11
  import torch.nn.functional as F
12
  from torch import nn
13
- from torch.nn import CrossEntropyLoss, LayerNorm
14
  from torch.nn.utils import skip_init
15
- from typing import Optional, Tuple, Union, List, Callable, Dict, Any
16
 
17
  from transformers.modeling_outputs import (
18
  BaseModelOutputWithPast,
19
  CausalLMOutputWithPast,
 
20
  )
21
  from transformers.modeling_utils import PreTrainedModel
22
- from transformers.utils import logging
23
  from transformers.generation.logits_process import LogitsProcessor
24
- from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
25
 
26
  from .configuration_chatglm import ChatGLMConfig
27
- from einops import rearrange
28
  try:
29
- from flash_attn.flash_attn_interface import flash_attn_unpadded_func
30
- except ImportError:
31
- try:
32
- # FlashAttention-2
33
- from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
34
- except ImportError:
35
- flash_attn_unpadded_func = None
36
 
37
  # flags required to enable jit fusion kernels
38
 
39
- if sys.platform != 'darwin':
40
  torch._C._jit_set_profiling_mode(False)
41
  torch._C._jit_set_profiling_executor(False)
42
  torch._C._jit_override_can_fuse_on_cpu(True)
@@ -44,13 +41,9 @@ if sys.platform != 'darwin':
44
 
45
  logger = logging.get_logger(__name__)
46
 
47
- _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
48
- _CONFIG_FOR_DOC = "ChatGLM6BConfig"
49
 
50
- CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
51
- "THUDM/chatglm2-6b",
52
- # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
53
- ]
54
 
55
  def default_init(cls, *args, **kwargs):
56
  return cls(*args, **kwargs)
@@ -60,22 +53,21 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
60
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
61
  if torch.isnan(scores).any() or torch.isinf(scores).any():
62
  scores.zero_()
63
- scores[..., 5] = 5e4
64
  return scores
65
 
 
66
  def split_tensor_along_last_dim(
67
  tensor: torch.Tensor,
68
  num_partitions: int,
69
  contiguous_split_chunks: bool = False,
70
  ) -> List[torch.Tensor]:
71
  """Split a tensor along its last dimension.
72
-
73
  Arguments:
74
  tensor: input tensor.
75
  num_partitions: number of partitions to split the tensor
76
  contiguous_split_chunks: If True, make each chunk contiguous
77
  in memory.
78
-
79
  Returns:
80
  A list of Tensors
81
  """
@@ -104,13 +96,11 @@ class RotaryEmbedding(nn.Module):
104
  self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
105
  ):
106
  """Enhanced Transformer with Rotary Position Embedding.
107
-
108
  Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
109
  transformers/rope/__init__.py. MIT License:
110
  https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
111
  """
112
  # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
113
-
114
  base = base * self.rope_ratio
115
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
116
 
@@ -135,14 +125,14 @@ class RotaryEmbedding(nn.Module):
135
 
136
  @torch.jit.script
137
  def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
138
- # x: [sq, b, np, hn]
139
- sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
140
  rot_dim = rope_cache.shape[-2] * 2
141
  x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
142
  # truncate to support variable sizes
143
- rope_cache = rope_cache[:sq]
144
- xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
145
- rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
146
  x_out2 = torch.stack(
147
  [
148
  xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
@@ -171,12 +161,13 @@ class RMSNorm(torch.nn.Module):
171
  class CoreAttention(torch.nn.Module):
172
  def __init__(self, config: ChatGLMConfig, layer_number):
173
  super(CoreAttention, self).__init__()
174
-
175
  self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
176
  self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
177
  if self.apply_query_key_layer_scaling:
178
  self.attention_softmax_in_fp32 = True
179
  self.layer_number = max(1, layer_number)
 
180
 
181
  projection_size = config.kv_channels * config.num_attention_heads
182
 
@@ -185,43 +176,213 @@ class CoreAttention(torch.nn.Module):
185
  self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
186
  self.num_attention_heads_per_partition = config.num_attention_heads
187
 
 
188
  self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
189
- self.attention_dropout = config.attention_dropout
 
 
 
 
 
190
 
191
  def forward(self, query_layer, key_layer, value_layer, attention_mask):
192
- seqlen_q, batch_size = query_layer.shape[0], query_layer.shape[1]
193
- seqlen_k = key_layer.shape[0]
194
- query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> (b s) ...') for x in [query_layer, key_layer, value_layer]]
195
- # DO flash_attn_varlen_func
196
- if attention_mask is None or attention_mask.ndim != 1:
197
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
198
- device=query_layer.device)
199
- else:
200
- assert seqlen_q == seqlen_k
201
- cu_seqlens_q = attention_mask
202
- if self.training:
203
- assert seqlen_k == seqlen_q
204
- is_causal = True
205
- cu_seqlens_k = cu_seqlens_q
206
- else:
207
- is_causal = seqlen_q == seqlen_k
208
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
209
- device=query_layer.device) if not is_causal else cu_seqlens_q
210
- self.attention_dropout = 0
211
- context_layer = flash_attn_unpadded_func(
212
- query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
213
- self.attention_dropout,
214
- softmax_scale=1.0 / self.norm_factor, causal=is_causal
215
  )
216
- context_layer = rearrange(context_layer, '(b s) ... -> s b ...', b=batch_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
218
  context_layer = context_layer.reshape(*new_context_layer_shape)
219
  return context_layer
220
 
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  class SelfAttention(torch.nn.Module):
223
  """Parallel self-attention layer abstract class.
224
-
225
  Self-attention layer takes input with size [s, b, h]
226
  and returns output of the same size.
227
  """
@@ -248,7 +409,7 @@ class SelfAttention(torch.nn.Module):
248
  device=device, **_config_to_kwargs(config)
249
  )
250
 
251
- self.core_attention = CoreAttention(config, self.layer_number)
252
 
253
  # Output.
254
  self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
@@ -272,7 +433,7 @@ class SelfAttention(torch.nn.Module):
272
  def forward(
273
  self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
274
  ):
275
- # hidden_states: [sq, b, h]
276
 
277
  # =================================================
278
  # Pre-allocate memory for key-values for inference.
@@ -281,7 +442,7 @@ class SelfAttention(torch.nn.Module):
281
  # Query, Key, and Value
282
  # =====================
283
 
284
- # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
285
  mixed_x_layer = self.query_key_value(hidden_states)
286
 
287
  if self.multi_query_attention:
@@ -309,39 +470,45 @@ class SelfAttention(torch.nn.Module):
309
  3 * self.hidden_size_per_attention_head)
310
  mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
311
 
312
- # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
313
  (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
314
 
 
 
 
315
  # apply relative positional encoding (rotary embedding)
316
  if rotary_pos_emb is not None:
317
  query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
318
  key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
319
 
320
  # adjust key and value for inference
 
 
 
 
321
  if use_cache:
322
- if kv_cache is not None:
323
- cache_k, cache_v = kv_cache
324
- key_layer = torch.cat((cache_k, key_layer), dim=0)
325
- value_layer = torch.cat((cache_v, value_layer), dim=0)
326
- kv_cache = (key_layer, value_layer)
327
  else:
328
  kv_cache = None
329
-
330
-
331
  if self.multi_query_attention:
332
- key_layer = key_layer.unsqueeze(-2)
333
  key_layer = key_layer.expand(
334
- -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
335
  )
336
  key_layer = key_layer.contiguous().view(
337
- key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
338
  )
339
- value_layer = value_layer.unsqueeze(-2)
340
  value_layer = value_layer.expand(
341
- -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
342
  )
343
  value_layer = value_layer.contiguous().view(
344
- value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
345
  )
346
 
347
  # ==================================
@@ -368,7 +535,6 @@ def _config_to_kwargs(args):
368
 
369
  class MLP(torch.nn.Module):
370
  """MLP.
371
-
372
  MLP will take the input with h hidden state, project it to 4*h
373
  hidden dimension, perform nonlinear transformation, and project the
374
  state back into h hidden dimension.
@@ -414,7 +580,6 @@ class MLP(torch.nn.Module):
414
 
415
  class GLMBlock(torch.nn.Module):
416
  """A single transformer layer.
417
-
418
  Transformer layer takes input with size [s, b, h] and returns an
419
  output of the same size.
420
  """
@@ -525,9 +690,9 @@ class GLMTransformer(torch.nn.Module):
525
  presents = () if use_cache else None
526
  if self.gradient_checkpointing and self.training:
527
  if use_cache:
528
- # logger.warning_once(
529
- # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
- # )
531
  use_cache = False
532
 
533
  all_self_attentions = None
@@ -557,7 +722,15 @@ class GLMTransformer(torch.nn.Module):
557
  )
558
  hidden_states, kv_cache = layer_ret
559
  if use_cache:
560
- presents = presents + (kv_cache,)
 
 
 
 
 
 
 
 
561
 
562
  if output_hidden_states:
563
  all_hidden_states = all_hidden_states + (hidden_states,)
@@ -580,18 +753,24 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
580
  config_class = ChatGLMConfig
581
  base_model_prefix = "transformer"
582
  _no_split_modules = ["GLMBlock"]
 
 
583
 
584
  def _init_weights(self, module: nn.Module):
585
  """Initialize the weights."""
586
  return
587
 
588
  def get_masks(self, input_ids, past_key_values, padding_mask=None):
 
 
 
 
589
  batch_size, seq_length = input_ids.shape
590
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
591
  full_attention_mask.tril_()
592
  past_length = 0
593
  if past_key_values:
594
- past_length = past_key_values[0][0].shape[0]
595
  if past_length:
596
  full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
597
  device=input_ids.device), full_attention_mask), dim=-1)
@@ -608,11 +787,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
608
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
609
  return position_ids
610
 
611
- def _set_gradient_checkpointing(self, module, value=False):
612
- if isinstance(module, GLMTransformer):
613
- module.gradient_checkpointing = value
614
-
615
-
616
  class Embedding(torch.nn.Module):
617
  """Language model embeddings."""
618
 
@@ -633,8 +807,6 @@ class Embedding(torch.nn.Module):
633
  # Embeddings.
634
  words_embeddings = self.word_embeddings(input_ids)
635
  embeddings = words_embeddings
636
- # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
637
- embeddings = embeddings.transpose(0, 1).contiguous()
638
  # If the input flag for fp32 residual connection is set, convert for float.
639
  if self.fp32_residual_connection:
640
  embeddings = embeddings.float()
@@ -652,6 +824,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
652
  if device is not None:
653
  init_kwargs["device"] = device
654
  self.embedding = init_method(Embedding, config, **init_kwargs)
 
 
 
655
 
656
  # Rotary positional embeddings
657
  self.seq_length = config.seq_length
@@ -659,7 +834,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
659
  config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
660
  )
661
 
662
- self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
 
663
  device=device, dtype=config.torch_dtype)
664
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
665
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
@@ -668,6 +844,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
668
  def get_input_embeddings(self):
669
  return self.embedding.word_embeddings
670
 
 
 
 
671
  def forward(
672
  self,
673
  input_ids,
@@ -677,6 +856,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
677
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
678
  inputs_embeds: Optional[torch.Tensor] = None,
679
  use_cache: Optional[bool] = None,
 
680
  output_hidden_states: Optional[bool] = None,
681
  return_dict: Optional[bool] = None,
682
  ):
@@ -691,9 +871,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
691
  if inputs_embeds is None:
692
  inputs_embeds = self.embedding(input_ids)
693
 
694
- # if full_attention_mask is None:
695
- # if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
696
- # full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
697
 
698
  # Rotary positional embeddings
699
  rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
@@ -701,13 +881,18 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
701
  rotary_pos_emb = rotary_pos_emb[position_ids]
702
  else:
703
  rotary_pos_emb = rotary_pos_emb[None, :seq_length]
704
- rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
705
 
706
  # Run encoder.
707
  hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
708
- inputs_embeds, attention_mask, rotary_pos_emb=rotary_pos_emb,
709
  kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
710
  )
 
 
 
 
 
 
711
 
712
  if not return_dict:
713
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
@@ -727,7 +912,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
727
  self.max_sequence_length = config.max_length
728
  self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
729
  self.config = config
730
- self.pack_loss = False
731
 
732
  def _update_model_kwargs_for_generation(
733
  self,
@@ -764,6 +948,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
764
  past_key_values: Optional[torch.Tensor] = None,
765
  attention_mask: Optional[torch.Tensor] = None,
766
  position_ids: Optional[torch.Tensor] = None,
 
767
  is_first_forward: bool = True,
768
  **kwargs
769
  ) -> dict:
@@ -771,14 +956,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
771
  if position_ids is None:
772
  position_ids = self.get_position_ids(input_ids, device=input_ids.device)
773
  if not is_first_forward:
774
- position_ids = position_ids[..., -1:]
775
- input_ids = input_ids[:, -1:]
 
776
  return {
777
  "input_ids": input_ids,
778
  "past_key_values": past_key_values,
779
  "position_ids": position_ids,
780
  "attention_mask": attention_mask,
781
- "return_last_logit": True
 
782
  }
783
 
784
  def forward(
@@ -788,7 +975,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
788
  attention_mask: Optional[torch.Tensor] = None,
789
  past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
790
  inputs_embeds: Optional[torch.Tensor] = None,
791
- labels: Optional[Tuple[torch.Tensor]] = None,
792
  use_cache: Optional[bool] = None,
793
  output_attentions: Optional[bool] = None,
794
  output_hidden_states: Optional[bool] = None,
@@ -811,30 +998,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
811
 
812
  hidden_states = transformer_outputs[0]
813
  if return_last_logit:
814
- hidden_states = hidden_states[-1:]
815
  lm_logits = self.transformer.output_layer(hidden_states)
816
- lm_logits = lm_logits.transpose(0, 1).contiguous()
817
 
818
  loss = None
819
  if labels is not None:
820
  lm_logits = lm_logits.to(torch.float32)
 
821
  # Shift so that tokens < n predict n
822
  shift_logits = lm_logits[..., :-1, :].contiguous()
823
- if isinstance(labels, tuple) or isinstance(labels, list):
824
- labels, weights = labels
825
  shift_labels = labels[..., 1:].contiguous()
826
- if self.pack_loss:
827
- loss_fct = CrossEntropyLoss(ignore_index=-100)#, reduction='none')
828
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
829
- loss *= weights
830
- # if self.pack_loss:
831
- # shift_weights = weights[..., 1:].contiguous()
832
- # loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
833
- # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
834
- # loss = (loss * shift_weights).sum()
835
- else:
836
- loss_fct = CrossEntropyLoss(ignore_index=-100)
837
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
838
 
839
  lm_logits = lm_logits.to(hidden_states.dtype)
840
  loss = loss.to(hidden_states.dtype)
@@ -859,33 +1035,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
859
  This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
860
  [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
861
  beam_idx at every generation step.
862
-
863
  Output shares the same memory storage as `past`.
864
  """
865
  return tuple(
866
  (
867
- layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
868
- layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
869
  )
870
  for layer_past in past
871
  )
872
 
873
- def process_response(self, response):
874
- response = response.strip()
875
- response = response.replace("[[训练时间]]", "2023年")
876
- return response
877
-
878
  @torch.inference_mode()
879
  def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
880
- max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
881
  **kwargs):
882
  if history is None:
883
  history = []
884
- if logits_processor is None:
885
- logits_processor = LogitsProcessorList()
886
- logits_processor.append(InvalidScoreLogitsProcessor())
887
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
888
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
889
  inputs = tokenizer.build_chat_input(query, history=history, role=role)
890
  inputs = inputs.to(self.device)
891
  eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
@@ -894,5 +1061,4 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
894
  outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
895
  response = tokenizer.decode(outputs)
896
  history.append({"role": role, "content": query})
897
- response = self.process_response(response)
898
  return response, history
 
1
  """ PyTorch ChatGLM model. """
2
 
3
  import math
 
 
 
4
  import sys
 
5
  import torch
6
  import torch.utils.checkpoint
7
  import torch.nn.functional as F
8
  from torch import nn
9
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
10
  from torch.nn.utils import skip_init
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
 
13
  from transformers.modeling_outputs import (
14
  BaseModelOutputWithPast,
15
  CausalLMOutputWithPast,
16
+ SequenceClassifierOutputWithPast,
17
  )
18
  from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging, is_torch_npu_available
20
  from transformers.generation.logits_process import LogitsProcessor
21
+ from transformers.generation.utils import ModelOutput
22
 
23
  from .configuration_chatglm import ChatGLMConfig
24
+
25
  try:
26
+ from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+ except:
32
+ pass
33
 
34
  # flags required to enable jit fusion kernels
35
 
36
+ if sys.platform != 'darwin' and not is_torch_npu_available():
37
  torch._C._jit_set_profiling_mode(False)
38
  torch._C._jit_set_profiling_executor(False)
39
  torch._C._jit_override_can_fuse_on_cpu(True)
 
41
 
42
  logger = logging.get_logger(__name__)
43
 
44
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
45
+ _CONFIG_FOR_DOC = "ChatGLMConfig"
46
 
 
 
 
 
47
 
48
  def default_init(cls, *args, **kwargs):
49
  return cls(*args, **kwargs)
 
53
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
54
  if torch.isnan(scores).any() or torch.isinf(scores).any():
55
  scores.zero_()
56
+ scores[..., 198] = 5e4
57
  return scores
58
 
59
+
60
  def split_tensor_along_last_dim(
61
  tensor: torch.Tensor,
62
  num_partitions: int,
63
  contiguous_split_chunks: bool = False,
64
  ) -> List[torch.Tensor]:
65
  """Split a tensor along its last dimension.
 
66
  Arguments:
67
  tensor: input tensor.
68
  num_partitions: number of partitions to split the tensor
69
  contiguous_split_chunks: If True, make each chunk contiguous
70
  in memory.
 
71
  Returns:
72
  A list of Tensors
73
  """
 
96
  self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
97
  ):
98
  """Enhanced Transformer with Rotary Position Embedding.
 
99
  Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
100
  transformers/rope/__init__.py. MIT License:
101
  https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
102
  """
103
  # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
 
104
  base = base * self.rope_ratio
105
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
106
 
 
125
 
126
  @torch.jit.script
127
  def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
128
+ # x: [b, np, sq, hn]
129
+ b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
130
  rot_dim = rope_cache.shape[-2] * 2
131
  x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
132
  # truncate to support variable sizes
133
+ rope_cache = rope_cache[:, :sq]
134
+ xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
135
+ rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
136
  x_out2 = torch.stack(
137
  [
138
  xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
 
161
  class CoreAttention(torch.nn.Module):
162
  def __init__(self, config: ChatGLMConfig, layer_number):
163
  super(CoreAttention, self).__init__()
164
+ self.config = config
165
  self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
166
  self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
167
  if self.apply_query_key_layer_scaling:
168
  self.attention_softmax_in_fp32 = True
169
  self.layer_number = max(1, layer_number)
170
+ self.is_causal = True
171
 
172
  projection_size = config.kv_channels * config.num_attention_heads
173
 
 
176
  self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
177
  self.num_attention_heads_per_partition = config.num_attention_heads
178
 
179
+ coeff = None
180
  self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
181
+ if self.apply_query_key_layer_scaling:
182
+ coeff = self.layer_number
183
+ self.norm_factor *= coeff
184
+ self.coeff = coeff
185
+
186
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
187
 
188
  def forward(self, query_layer, key_layer, value_layer, attention_mask):
189
+ # [b, np, sq, sk]
190
+ output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
191
+
192
+ # [b, np, sq, hn] -> [b * np, sq, hn]
193
+ query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
194
+ # [b, np, sk, hn] -> [b * np, sk, hn]
195
+ key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
196
+
197
+ # preallocting input tensor: [b * np, sq, sk]
198
+ matmul_input_buffer = torch.empty(
199
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
200
+ device=query_layer.device
201
+ )
202
+
203
+ # Raw attention scores. [b * np, sq, sk]
204
+ matmul_result = torch.baddbmm(
205
+ matmul_input_buffer,
206
+ query_layer, # [b * np, sq, hn]
207
+ key_layer.transpose(1, 2), # [b * np, hn, sk]
208
+ beta=0.0,
209
+ alpha=(1.0 / self.norm_factor),
 
 
210
  )
211
+
212
+ # change view to [b, np, sq, sk]
213
+ attention_scores = matmul_result.view(*output_size)
214
+
215
+ # ===========================
216
+ # Attention probs and dropout
217
+ # ===========================
218
+
219
+ # attention scores and attention mask [b, np, sq, sk]
220
+ if self.attention_softmax_in_fp32:
221
+ attention_scores = attention_scores.float()
222
+ if self.coeff is not None:
223
+ attention_scores = attention_scores * self.coeff
224
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
225
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
226
+ device=attention_scores.device, dtype=torch.bool)
227
+ attention_mask.tril_()
228
+ attention_mask = ~attention_mask
229
+ if attention_mask is not None:
230
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
231
+ attention_probs = F.softmax(attention_scores, dim=-1)
232
+ attention_probs = attention_probs.type_as(value_layer)
233
+
234
+ # This is actually dropping out entire tokens to attend to, which might
235
+ # seem a bit unusual, but is taken from the original Transformer paper.
236
+ attention_probs = self.attention_dropout(attention_probs)
237
+
238
+ # query layer shape: [b * np, sq, hn]
239
+ # value layer shape: [b, np, sk, hn]
240
+ # attention shape: [b, np, sq, sk]
241
+ # context layer shape: [b, np, sq, hn]
242
+ output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
243
+ # change view [b * np, sk, hn]
244
+ value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
245
+ # change view [b * np, sq, sk]
246
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
247
+ # matmul: [b * np, sq, hn]
248
+ context_layer = torch.bmm(attention_probs, value_layer)
249
+ # change view [b, np, sq, hn]
250
+ context_layer = context_layer.view(*output_size)
251
+ # [b, np, sq, hn] --> [b, sq, np, hn]
252
+ context_layer = context_layer.transpose(1, 2).contiguous()
253
+ # [b, sq, np, hn] --> [b, sq, hp]
254
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
255
+ context_layer = context_layer.reshape(*new_context_layer_shape)
256
+
257
+ return context_layer
258
+
259
+
260
+ class SdpaAttention(CoreAttention):
261
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
262
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
263
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
264
+ is_causal=True,
265
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
266
+ else:
267
+ if attention_mask is not None:
268
+ attention_mask = ~attention_mask
269
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
270
+ attention_mask,
271
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
272
+ context_layer = context_layer.transpose(1, 2).contiguous()
273
  new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
274
  context_layer = context_layer.reshape(*new_context_layer_shape)
275
  return context_layer
276
 
277
 
278
+ def _get_unpad_data(attention_mask):
279
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
280
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
281
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
282
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
283
+ return (
284
+ indices,
285
+ cu_seqlens,
286
+ max_seqlen_in_batch,
287
+ )
288
+
289
+
290
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
291
+ class FlashAttention2(CoreAttention):
292
+ def __init__(self, *args, **kwargs):
293
+ super().__init__(*args, **kwargs)
294
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
295
+
296
+ def forward(self, query_states, key_states, value_states, attention_mask):
297
+ query_states = query_states.transpose(1, 2)
298
+ key_states = key_states.transpose(1, 2)
299
+ value_states = value_states.transpose(1, 2)
300
+ batch_size, query_length = query_states.shape[:2]
301
+ if not self._flash_attn_uses_top_left_mask:
302
+ causal = self.is_causal
303
+ else:
304
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
305
+ causal = self.is_causal and query_length != 1
306
+ dropout = self.config.attention_dropout if self.training else 0.0
307
+ # Contains at least one padding token in the sequence
308
+ if attention_mask is not None:
309
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
310
+ query_states, key_states, value_states, attention_mask, query_length
311
+ )
312
+
313
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
314
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
315
+
316
+ attn_output_unpad = flash_attn_varlen_func(
317
+ query_states,
318
+ key_states,
319
+ value_states,
320
+ cu_seqlens_q=cu_seqlens_q,
321
+ cu_seqlens_k=cu_seqlens_k,
322
+ max_seqlen_q=max_seqlen_in_batch_q,
323
+ max_seqlen_k=max_seqlen_in_batch_k,
324
+ dropout_p=dropout,
325
+ softmax_scale=None,
326
+ causal=causal,
327
+ )
328
+
329
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
330
+ else:
331
+ attn_output = flash_attn_func(
332
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
333
+ )
334
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
335
+ return attn_output
336
+
337
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
338
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
339
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
340
+
341
+ key_layer = index_first_axis(
342
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
343
+ )
344
+ value_layer = index_first_axis(
345
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
346
+ )
347
+ if query_length == kv_seq_len:
348
+ query_layer = index_first_axis(
349
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
350
+ indices_k
351
+ )
352
+ cu_seqlens_q = cu_seqlens_k
353
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
354
+ indices_q = indices_k
355
+ elif query_length == 1:
356
+ max_seqlen_in_batch_q = 1
357
+ cu_seqlens_q = torch.arange(
358
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
359
+ ) # There is a memcpy here, that is very bad.
360
+ indices_q = cu_seqlens_q[:-1]
361
+ query_layer = query_layer.squeeze(1)
362
+ else:
363
+ # The -q_len: slice assumes left padding.
364
+ attention_mask = attention_mask[:, -query_length:]
365
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
366
+
367
+ return (
368
+ query_layer,
369
+ key_layer,
370
+ value_layer,
371
+ indices_q,
372
+ (cu_seqlens_q, cu_seqlens_k),
373
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
374
+ )
375
+
376
+
377
+ CORE_ATTENTION_CLASSES = {
378
+ "eager": CoreAttention,
379
+ "sdpa": SdpaAttention,
380
+ "flash_attention_2": FlashAttention2
381
+ }
382
+
383
+
384
  class SelfAttention(torch.nn.Module):
385
  """Parallel self-attention layer abstract class.
 
386
  Self-attention layer takes input with size [s, b, h]
387
  and returns output of the same size.
388
  """
 
409
  device=device, **_config_to_kwargs(config)
410
  )
411
 
412
+ self.core_attention = CORE_ATTENTION_CLASSES[config._attn_implementation](config, self.layer_number)
413
 
414
  # Output.
415
  self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
 
433
  def forward(
434
  self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
435
  ):
436
+ # hidden_states: [b, sq, h]
437
 
438
  # =================================================
439
  # Pre-allocate memory for key-values for inference.
 
442
  # Query, Key, and Value
443
  # =====================
444
 
445
+ # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)]
446
  mixed_x_layer = self.query_key_value(hidden_states)
447
 
448
  if self.multi_query_attention:
 
470
  3 * self.hidden_size_per_attention_head)
471
  mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
472
 
473
+ # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
474
  (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
475
 
476
+ # [b, sq, np, hn] -> [b, np, sq, hn]
477
+ query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]
478
+
479
  # apply relative positional encoding (rotary embedding)
480
  if rotary_pos_emb is not None:
481
  query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
482
  key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
483
 
484
  # adjust key and value for inference
485
+ if kv_cache is not None:
486
+ cache_k, cache_v = kv_cache
487
+ key_layer = torch.cat((cache_k, key_layer), dim=2)
488
+ value_layer = torch.cat((cache_v, value_layer), dim=2)
489
  if use_cache:
490
+ if kv_cache is None:
491
+ kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
492
+ dim=1)
493
+ else:
494
+ kv_cache = (key_layer, value_layer)
495
  else:
496
  kv_cache = None
497
+
 
498
  if self.multi_query_attention:
499
+ key_layer = key_layer.unsqueeze(2)
500
  key_layer = key_layer.expand(
501
+ -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
502
  )
503
  key_layer = key_layer.contiguous().view(
504
+ key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]
505
  )
506
+ value_layer = value_layer.unsqueeze(2)
507
  value_layer = value_layer.expand(
508
+ -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
509
  )
510
  value_layer = value_layer.contiguous().view(
511
+ value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]
512
  )
513
 
514
  # ==================================
 
535
 
536
  class MLP(torch.nn.Module):
537
  """MLP.
 
538
  MLP will take the input with h hidden state, project it to 4*h
539
  hidden dimension, perform nonlinear transformation, and project the
540
  state back into h hidden dimension.
 
580
 
581
  class GLMBlock(torch.nn.Module):
582
  """A single transformer layer.
 
583
  Transformer layer takes input with size [s, b, h] and returns an
584
  output of the same size.
585
  """
 
690
  presents = () if use_cache else None
691
  if self.gradient_checkpointing and self.training:
692
  if use_cache:
693
+ logger.warning_once(
694
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
695
+ )
696
  use_cache = False
697
 
698
  all_self_attentions = None
 
722
  )
723
  hidden_states, kv_cache = layer_ret
724
  if use_cache:
725
+ # token by token decoding, use tuple format
726
+ if kv_caches[0] is not None:
727
+ presents = presents + (kv_cache,)
728
+ # prefilling in decoding, use tensor format to save cuda memory
729
+ else:
730
+ if len(presents) == 0:
731
+ presents = kv_cache
732
+ else:
733
+ presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
734
 
735
  if output_hidden_states:
736
  all_hidden_states = all_hidden_states + (hidden_states,)
 
753
  config_class = ChatGLMConfig
754
  base_model_prefix = "transformer"
755
  _no_split_modules = ["GLMBlock"]
756
+ _supports_flash_attn_2 = True
757
+ _supports_sdpa = True
758
 
759
  def _init_weights(self, module: nn.Module):
760
  """Initialize the weights."""
761
  return
762
 
763
  def get_masks(self, input_ids, past_key_values, padding_mask=None):
764
+ if self.config._attn_implementation == "flash_attention_2":
765
+ if padding_mask is not None and not padding_mask.all():
766
+ return padding_mask
767
+ return None
768
  batch_size, seq_length = input_ids.shape
769
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
770
  full_attention_mask.tril_()
771
  past_length = 0
772
  if past_key_values:
773
+ past_length = past_key_values[0][0].shape[2]
774
  if past_length:
775
  full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
776
  device=input_ids.device), full_attention_mask), dim=-1)
 
787
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
788
  return position_ids
789
 
 
 
 
 
 
790
  class Embedding(torch.nn.Module):
791
  """Language model embeddings."""
792
 
 
807
  # Embeddings.
808
  words_embeddings = self.word_embeddings(input_ids)
809
  embeddings = words_embeddings
 
 
810
  # If the input flag for fp32 residual connection is set, convert for float.
811
  if self.fp32_residual_connection:
812
  embeddings = embeddings.float()
 
824
  if device is not None:
825
  init_kwargs["device"] = device
826
  self.embedding = init_method(Embedding, config, **init_kwargs)
827
+ self.num_layers = config.num_layers
828
+ self.multi_query_group_num = config.multi_query_group_num
829
+ self.kv_channels = config.kv_channels
830
 
831
  # Rotary positional embeddings
832
  self.seq_length = config.seq_length
 
834
  config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
835
  )
836
 
837
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
838
+ original_impl=config.original_rope,
839
  device=device, dtype=config.torch_dtype)
840
  self.encoder = init_method(GLMTransformer, config, **init_kwargs)
841
  self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
 
844
  def get_input_embeddings(self):
845
  return self.embedding.word_embeddings
846
 
847
+ def set_input_embeddings(self, value):
848
+ self.embedding.word_embeddings = value
849
+
850
  def forward(
851
  self,
852
  input_ids,
 
856
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
857
  inputs_embeds: Optional[torch.Tensor] = None,
858
  use_cache: Optional[bool] = None,
859
+ output_attentions: Optional[bool] = None,
860
  output_hidden_states: Optional[bool] = None,
861
  return_dict: Optional[bool] = None,
862
  ):
 
871
  if inputs_embeds is None:
872
  inputs_embeds = self.embedding(input_ids)
873
 
874
+ if full_attention_mask is None:
875
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
876
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
877
 
878
  # Rotary positional embeddings
879
  rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
 
881
  rotary_pos_emb = rotary_pos_emb[position_ids]
882
  else:
883
  rotary_pos_emb = rotary_pos_emb[None, :seq_length]
 
884
 
885
  # Run encoder.
886
  hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
887
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
888
  kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
889
  )
890
+ if presents is not None and type(presents) is torch.Tensor:
891
+ presents = presents.split(1, dim=0)
892
+ presents = list(presents)
893
+ presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
894
+ presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
895
+ presents = tuple(presents)
896
 
897
  if not return_dict:
898
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
 
912
  self.max_sequence_length = config.max_length
913
  self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
914
  self.config = config
 
915
 
916
  def _update_model_kwargs_for_generation(
917
  self,
 
948
  past_key_values: Optional[torch.Tensor] = None,
949
  attention_mask: Optional[torch.Tensor] = None,
950
  position_ids: Optional[torch.Tensor] = None,
951
+ use_cache: Optional[bool] = None,
952
  is_first_forward: bool = True,
953
  **kwargs
954
  ) -> dict:
 
956
  if position_ids is None:
957
  position_ids = self.get_position_ids(input_ids, device=input_ids.device)
958
  if not is_first_forward:
959
+ if past_key_values is not None:
960
+ position_ids = position_ids[..., -1:]
961
+ input_ids = input_ids[:, -1:]
962
  return {
963
  "input_ids": input_ids,
964
  "past_key_values": past_key_values,
965
  "position_ids": position_ids,
966
  "attention_mask": attention_mask,
967
+ "return_last_logit": True,
968
+ "use_cache": use_cache
969
  }
970
 
971
  def forward(
 
975
  attention_mask: Optional[torch.Tensor] = None,
976
  past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
977
  inputs_embeds: Optional[torch.Tensor] = None,
978
+ labels: Optional[torch.Tensor] = None,
979
  use_cache: Optional[bool] = None,
980
  output_attentions: Optional[bool] = None,
981
  output_hidden_states: Optional[bool] = None,
 
998
 
999
  hidden_states = transformer_outputs[0]
1000
  if return_last_logit:
1001
+ hidden_states = hidden_states[:, -1:]
1002
  lm_logits = self.transformer.output_layer(hidden_states)
 
1003
 
1004
  loss = None
1005
  if labels is not None:
1006
  lm_logits = lm_logits.to(torch.float32)
1007
+
1008
  # Shift so that tokens < n predict n
1009
  shift_logits = lm_logits[..., :-1, :].contiguous()
 
 
1010
  shift_labels = labels[..., 1:].contiguous()
1011
+ # Flatten the tokens
1012
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1013
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
 
 
 
 
 
 
 
 
1014
 
1015
  lm_logits = lm_logits.to(hidden_states.dtype)
1016
  loss = loss.to(hidden_states.dtype)
 
1035
  This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1036
  [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1037
  beam_idx at every generation step.
 
1038
  Output shares the same memory storage as `past`.
1039
  """
1040
  return tuple(
1041
  (
1042
+ layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
1043
+ layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
1044
  )
1045
  for layer_past in past
1046
  )
1047
 
 
 
 
 
 
1048
  @torch.inference_mode()
1049
  def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1050
+ max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8,
1051
  **kwargs):
1052
  if history is None:
1053
  history = []
 
 
 
1054
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1055
+ "temperature": temperature, **kwargs}
1056
  inputs = tokenizer.build_chat_input(query, history=history, role=role)
1057
  inputs = inputs.to(self.device)
1058
  eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
 
1061
  outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1062
  response = tokenizer.decode(outputs)
1063
  history.append({"role": role, "content": query})
 
1064
  return response, history