Upload modeling_chatglm.py

#2 by bigmoyan - opened
Files changed (1)
  1. modeling_chatglm.py +16 -3
modeling_chatglm.py CHANGED
@@ -827,7 +827,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         init_method = default_init
         init_kwargs = {}
         if device is not None:
-            init_kwargs["device"] = device
+            init_kwargs["device"] = torch.device(device)
+        if isinstance(config.torch_dtype, str):
+            config.torch_dtype = getattr(torch, config.torch_dtype)
         self.embedding = init_method(Embedding, config, **init_kwargs)
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
@@ -923,10 +925,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs: ModelOutput,
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
         # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(outputs)
-        model_kwargs[cache_name] = cache
+        for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
+            if hasattr(outputs, possible_cache_name):
+                if possible_cache_name in ("past_buckets_states", "mems"):
+                    cache_name = "past_key_values"
+                else:
+                    cache_name = possible_cache_name
+                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
+                break
 
         # update attention mask
         if "attention_mask" in model_kwargs:
@@ -945,6 +954,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             )
 
         model_kwargs["is_first_forward"] = False
+
+        if model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
         return model_kwargs
 
     def prepare_inputs_for_generation(
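
For reference, the inlined cache lookup that replaces the call to self._extract_past_from_model_output can be exercised on its own. The sketch below is illustrative and not part of the PR: the helper name extract_cache and the SimpleNamespace stand-in for a model output object are assumptions made for the example.

# Minimal sketch of the cache lookup inlined in the diff above (illustrative only).
from types import SimpleNamespace
from typing import Any, Dict

def extract_cache(outputs: Any, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Probe the output object for the cache attribute used by the model family;
    # the first attribute found wins.
    for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
        if hasattr(outputs, possible_cache_name):
            # "mems" and "past_buckets_states" are stored back under the generic
            # "past_key_values" key that the generation loop expects.
            if possible_cache_name in ("past_buckets_states", "mems"):
                cache_name = "past_key_values"
            else:
                cache_name = possible_cache_name
            model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
            break
    return model_kwargs

# Stand-in output object carrying a past_key_values attribute.
fake_outputs = SimpleNamespace(past_key_values=[("k", "v")])
print(extract_cache(fake_outputs, {}))  # {'past_key_values': [('k', 'v')]}

Unlike the removed line, this version does not depend on the private _extract_past_from_model_output helper, which newer transformers releases appear to have dropped.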