Update modeling_chatglm.py
Browse files- modeling_chatglm.py +7 -1
modeling_chatglm.py
CHANGED
@@ -42,6 +42,7 @@ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
|
|
42 |
_CONFIG_FOR_DOC = "ChatGLMConfig"
|
43 |
|
44 |
is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
|
|
|
45 |
|
46 |
|
47 |
def default_init(cls, *args, **kwargs):
|
@@ -812,8 +813,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
812 |
is_encoder_decoder: bool = False,
|
813 |
standardize_cache_format: bool = False,
|
814 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
815 |
# update past_key_values
|
816 |
-
|
817 |
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
818 |
outputs, standardize_cache_format=standardize_cache_format
|
819 |
)[1]
|
|
|
42 |
_CONFIG_FOR_DOC = "ChatGLMConfig"
|
43 |
|
44 |
is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
|
45 |
+
is_transformers_4_44_or_higher = int(transformers.__version__.split(".")[1]) >= 44
|
46 |
|
47 |
|
48 |
def default_init(cls, *args, **kwargs):
|
|
|
813 |
is_encoder_decoder: bool = False,
|
814 |
standardize_cache_format: bool = False,
|
815 |
) -> Dict[str, Any]:
|
816 |
+
|
817 |
+
if is_transformers_4_44_or_higher:
|
818 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
819 |
+
outputs
|
820 |
+
)[1]
|
821 |
# update past_key_values
|
822 |
+
elif is_transformers_4_42_or_higher:
|
823 |
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
824 |
outputs, standardize_cache_format=standardize_cache_format
|
825 |
)[1]
|