zR committed on
Commit
12372fe
1 Parent(s): f5ce116

transformers 4.42.4 upgrade

Files changed (4):
  1. README.md +2 -0
  2. README_en.md +2 -0
  3. config.json +1 -1
  4. modeling_chatglm.py +6 -205
README.md CHANGED
@@ -76,6 +76,8 @@ GLM-4-9B 是智谱 AI 推出的最新一代预训练模型 GLM-4 系列中的开
 
 ### 使用 transformers 后端进行推理:
 
+**请使用 transformers == 4.42.4**
+
 ```python
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
README_en.md CHANGED
@@ -70,6 +70,8 @@ For more inference code and requirements, please visit our [github page](https:/
 
 ### Use the following method to quickly call the GLM-4-9B-Chat language model
 
+**Please use transformers == 4.42.4**
+
 Use the transformers backend for inference:
 
 ```python
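
In this view both READMEs cut off right after the import lines of the transformers-backend snippet. For context, a typical continuation looks like the sketch below. It is a hedged illustration, not the files' verbatim content: the repo id `THUDM/glm-4-9b-chat`, the prompt, and the generation settings are assumptions.

```python
# Hedged sketch of how the README snippet typically continues; repo id, prompt
# and generation settings are assumptions, not the files' verbatim content.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-4-9b-chat",
    torch_dtype=torch.bfloat16,   # matches "torch_dtype": "bfloat16" in config.json
    trust_remote_code=True,
).to(device).eval()

# Build chat-formatted inputs from the tokenizer's chat template.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello, what can you do?"}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
).to(device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, top_p=0.8, temperature=0.8)

# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```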
config.json CHANGED
@@ -38,7 +38,7 @@
   "seq_length": 131072,
   "use_cache": true,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.40.2",
+  "transformers_version": "4.42.4",
   "tie_word_embeddings": false,
   "eos_token_id": [151329, 151336, 151338],
   "pad_token_id": 151329
modeling_chatglm.py CHANGED
@@ -1,19 +1,14 @@
 """ PyTorch ChatGLM model. """
-import json
+
 import math
-import copy
-import warnings
-import re
 import sys
-
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
 from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List, Callable, Dict, Any
-from copy import deepcopy
+from typing import Optional, Tuple, Union, List, Dict, Any
 
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -23,19 +18,19 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging, is_torch_npu_available
 from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
+from transformers.generation.utils import ModelOutput
 
 from .configuration_chatglm import ChatGLMConfig
 
 try:
     from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
+
     if is_flash_attn_2_available():
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 except:
     pass
 
-
 # flags required to enable jit fusion kernels
 
 if sys.platform != 'darwin' and not is_torch_npu_available():
@@ -354,7 +349,8 @@ class FlashAttention2(CoreAttention):
         )
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), indices_k
+                query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
+                indices_k
             )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -1064,201 +1060,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
-    def process_response(self, output, history):
-        content = ""
-        history = deepcopy(history)
-        for response in output.split("<|assistant|>"):
-            if "\n" in response:
-                metadata, content = response.split("\n", maxsplit=1)
-            else:
-                metadata, content = "", response
-            if not metadata.strip():
-                content = content.strip()
-                history.append({"role": "assistant", "metadata": metadata, "content": content})
-                content = content.replace("[[训练时间]]", "2023年")
-            else:
-                history.append({"role": "assistant", "metadata": metadata, "content": content})
-                if history[0]["role"] == "system" and "tools" in history[0]:
-                    parameters = json.loads(content)
-                    content = {"name": metadata.strip(), "parameters": parameters}
-                else:
-                    content = {"name": metadata.strip(), "content": content}
-        return content, history
-
-    @torch.inference_mode()
-    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
-             **kwargs):
-        if history is None:
-            history = []
-        if logits_processor is None:
-            logits_processor = LogitsProcessorList()
-        logits_processor.append(InvalidScoreLogitsProcessor())
-        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        history.append({"role": role, "content": query})
-        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,
-                                               return_tensors="pt", return_dict=True)
-        inputs = inputs.to(self.device)
-        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
-                        tokenizer.convert_tokens_to_ids("<|observation|>")]
-        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
-        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
-        response = tokenizer.decode(outputs)
-        response, history = self.process_response(response, history)
-        return response, history
-
-    @torch.inference_mode()
-    def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
-                    logits_processor=None, return_past_key_values=False, **kwargs):
-        if history is None:
-            history = []
-        if logits_processor is None:
-            logits_processor = LogitsProcessorList()
-        logits_processor.append(InvalidScoreLogitsProcessor())
-        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
-                        tokenizer.convert_tokens_to_ids("<|observation|>")]
-        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
-                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        if past_key_values is None:
-            inputs = tokenizer.apply_chat_template(history + [{"role": role, "content": query}],
-                                                   add_generation_prompt=True, tokenize=True, return_tensors="pt",
-                                                   return_dict=True)
-        else:
-            inputs = tokenizer.apply_chat_template([{"role": role, "content": query}], add_special_tokens=False,
-                                                   add_generation_prompt=True, tokenize=True, return_tensors="pt",
-                                                   return_dict=True)
-        inputs = inputs.to(self.device)
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            inputs.position_ids += past_length
-            attention_mask = inputs.attention_mask
-            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
-            inputs['attention_mask'] = attention_mask
-        history.append({"role": role, "content": query})
-        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
-                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
-                                            **gen_kwargs):
-            if return_past_key_values:
-                outputs, past_key_values = outputs
-            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
-            response = tokenizer.decode(outputs)
-            if response and response[-1] != "�":
-                response, new_history = self.process_response(response, history)
-                if return_past_key_values:
-                    yield response, new_history, past_key_values
-                else:
-                    yield response, new_history
-
-    @torch.inference_mode()
-    def stream_generate(
-            self,
-            input_ids,
-            generation_config: Optional[GenerationConfig] = None,
-            logits_processor: Optional[LogitsProcessorList] = None,
-            stopping_criteria: Optional[StoppingCriteriaList] = None,
-            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-            return_past_key_values=False,
-            **kwargs,
-    ):
-        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
-
-        if generation_config is None:
-            generation_config = self.generation_config
-        generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs)
-        model_kwargs["use_cache"] = generation_config.use_cache
-        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
-
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        if has_default_max_length and generation_config.max_new_tokens is None:
-            warnings.warn(
-                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-                " recommend using `max_new_tokens` to control the maximum length of the generation.",
-                UserWarning,
-            )
-        elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-            if not has_default_max_length:
-                logger.warn(
-                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
-                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
-                    "Please refer to the documentation for more information. "
-                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
-                    UserWarning,
-                )
-
-        if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-            logger.warning(
-                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-                " increasing `max_new_tokens`."
-            )
-
-        # 2. Set generation parameters if not already defined
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        logits_processor = self._get_logits_processor(
-            generation_config=generation_config,
-            input_ids_seq_length=input_ids_seq_length,
-            encoder_input_ids=input_ids,
-            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-            logits_processor=logits_processor,
-        )
-
-        stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria
-        )
-        logits_warper = self._get_logits_warper(generation_config)
-
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-        scores = None
-        while True:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=False,
-                output_hidden_states=False,
-            )
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
-
-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            if generation_config.do_sample:
-                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-            else:
-                next_tokens = torch.argmax(probs, dim=-1)
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-            )
-            unfinished_sequences = unfinished_sequences.mul(
-                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-            )
-            if return_past_key_values:
-                yield input_ids, outputs.past_key_values
-            else:
-                yield input_ids
-            # stop when each sentence is finished, or if we exceed the maximum length
-            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
-                break
-
 
 class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
     def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
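
The deleted `process_response()`, `chat()`, `stream_chat()`, and `stream_generate()` helpers relied on `LogitsProcessorList`, `StoppingCriteriaList`, and `GenerationConfig` from `transformers.generation.utils`, whose imports are dropped above. With transformers 4.42.4 an equivalent chat workflow can be driven entirely through the standard `generate()` API. The sketch below is a hedged illustration, not part of the commit: the chat-template call and the stop-token list are taken from the removed `chat()` code, while the repo id, prompt, and generation settings are assumptions, and `TextIteratorStreamer` is a stock transformers utility standing in for `stream_chat()`.

```python
# Hedged sketch: replacing the removed chat()/stream_chat() helpers with
# tokenizer.apply_chat_template() + model.generate() plus a TextIteratorStreamer.
# Repo id, prompt, device handling and generation settings are assumptions.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_path = "THUDM/glm-4-9b-chat"  # assumed repo id
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
).to(device).eval()

# Same chat-template call and stop tokens the removed chat() helper used.
history = [{"role": "user", "content": "你好，请介绍一下 GLM-4"}]
inputs = tokenizer.apply_chat_template(
    history, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
).to(device)
eos_token_id = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|user|>"),
    tokenizer.convert_tokens_to_ids("<|observation|>"),
]

# Stream decoded text from a background generate() call, instead of stream_chat().
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    **inputs,
    streamer=streamer,
    max_new_tokens=512,
    do_sample=True,
    top_p=0.8,
    temperature=0.8,
    eos_token_id=eos_token_id,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```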