katuni4ka committed
Commit 4435425
1 Parent(s): abce7c2

Update modeling_chatglm.py

Files changed (1):
  modeling_chatglm.py +7 -1
modeling_chatglm.py CHANGED
@@ -42,6 +42,7 @@ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
 _CONFIG_FOR_DOC = "ChatGLMConfig"
 
 is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
+is_transformers_4_44_or_higher = int(transformers.__version__.split(".")[1]) >= 44
 
 
 def default_init(cls, *args, **kwargs):
@@ -812,8 +813,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
+
+        if is_transformers_4_44_or_higher:
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs
+            )[1]
         # update past_key_values
-        if is_transformers_4_42_or_higher:
+        elif is_transformers_4_42_or_higher:
             model_kwargs["past_key_values"] = self._extract_past_from_model_output(
                 outputs, standardize_cache_format=standardize_cache_format
             )[1]
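
The change adds a second version flag so that, on transformers 4.44 and newer, _extract_past_from_model_output is called without the standardize_cache_format keyword while the result is still indexed with [1]; the existing 4.42 branch is kept as an elif for older releases. Below is a minimal sketch of the same version-gating idea. Using packaging.version to parse the full version string is an illustrative alternative for this sketch, not part of the original commit, which compares only the minor version component.

# A minimal sketch of the version gate used in this commit. The use of
# packaging.version is an assumption/alternative for illustration; the
# commit itself checks int(transformers.__version__.split(".")[1]).
import transformers
from packaging import version

_transformers_version = version.parse(transformers.__version__)

# Parsing the full version string keeps the comparison valid across a
# future major-version bump, unlike comparing the minor component alone.
is_transformers_4_42_or_higher = _transformers_version >= version.parse("4.42")
is_transformers_4_44_or_higher = _transformers_version >= version.parse("4.44")

# As in the diff, these flags decide whether the call to
# _extract_past_from_model_output passes standardize_cache_format
# (pre-4.44 branches) or omits it (4.44+), with [1] taking the cache out
# of the returned pair in both gated branches.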