Commit
·
ec80a0a
1
Parent(s):
8f6e343
update modeling_baichuan.py for torchscript mode with past_kv
Browse filesto enable model inference with use_cache and return_dict from model.config.
- modeling_baichuan.py +4 -2
modeling_baichuan.py
CHANGED
|
@@ -365,7 +365,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
| 365 |
use_cache: Optional[bool] = False,
|
| 366 |
output_attentions: Optional[bool] = False,
|
| 367 |
output_hidden_states: Optional[bool] = False,
|
| 368 |
-
return_dict: Optional[bool] =
|
| 369 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 370 |
if input_ids is not None and inputs_embeds is not None:
|
| 371 |
raise ValueError(
|
|
@@ -378,6 +378,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
| 378 |
else:
|
| 379 |
raise ValueError("You need to provide input_ids or inputs_embeds")
|
| 380 |
|
|
|
|
|
|
|
| 381 |
return_dict = (
|
| 382 |
return_dict if return_dict is not None else self.config.use_return_dict
|
| 383 |
)
|
|
@@ -682,7 +684,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|
| 682 |
use_cache: Optional[bool] = None,
|
| 683 |
output_attentions: Optional[bool] = False,
|
| 684 |
output_hidden_states: Optional[bool] = False,
|
| 685 |
-
return_dict: Optional[bool] =
|
| 686 |
**kwargs,
|
| 687 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 688 |
return_dict = (
|
|
|
|
| 365 |
use_cache: Optional[bool] = False,
|
| 366 |
output_attentions: Optional[bool] = False,
|
| 367 |
output_hidden_states: Optional[bool] = False,
|
| 368 |
+
return_dict: Optional[bool] = None,
|
| 369 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 370 |
if input_ids is not None and inputs_embeds is not None:
|
| 371 |
raise ValueError(
|
|
|
|
| 378 |
else:
|
| 379 |
raise ValueError("You need to provide input_ids or inputs_embeds")
|
| 380 |
|
| 381 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 382 |
+
|
| 383 |
return_dict = (
|
| 384 |
return_dict if return_dict is not None else self.config.use_return_dict
|
| 385 |
)
|
|
|
|
| 684 |
use_cache: Optional[bool] = None,
|
| 685 |
output_attentions: Optional[bool] = False,
|
| 686 |
output_hidden_states: Optional[bool] = False,
|
| 687 |
+
return_dict: Optional[bool] = None,
|
| 688 |
**kwargs,
|
| 689 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 690 |
return_dict = (
|