Update yuan_hf_model.py
yuan_hf_model.py  (+4 -3)
@@ -32,8 +32,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_yuan import YuanConfig
 from einops import rearrange
-from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
-from flash_attn import flash_attn_func
+#from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
+#from flash_attn import flash_attn_func
 
 import copy
 
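The hunk above comments the two flash_attn imports out unconditionally, so every environment loses the fused-attention fast path, not just the ones where flash-attn is unavailable. A guarded import is one possible alternative, sketched below; the use_flash_attention flag is a hypothetical name introduced here for illustration, not part of this commit.

# Sketch of a guarded import (an alternative, not what this commit does):
# keep flash-attn when the package is installed, fall back cleanly when
# it is not (e.g. CPU-only installs). `use_flash_attention` is a
# hypothetical flag added for illustration.
try:
    from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
    from flash_attn import flash_attn_func
    use_flash_attention = True
except ImportError:
    flash_attn_unpadded_func = None
    flash_attn_func = None
    use_flash_attention = False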
@@ -268,7 +268,8 @@ class YuanAttention(nn.Module):
         is_first_step = False
         if use_cache:
             if past_key_value is None:
-                inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
+                #inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
+                inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
                 is_first_step = True
             else:
                 before_hidden_states = past_key_value[2]
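The removed line pinned the memory buffer to torch.cuda.current_device(), which raises on machines without CUDA; the replacement omits the device argument, so torch.empty allocates on the default device (normally CPU). If the surrounding code may run on either backend, allocating on the device of the incoming activations is a device-agnostic middle ground. A minimal sketch, assuming hidden_states already lives on the target device; this is not what the commit does:

# Sketch (an assumption, not this commit): follow the input tensor's
# device so both CPU and CUDA inference allocate the buffer correctly.
inference_hidden_states_memory = torch.empty(
    bsz, 2, hidden_states.shape[2],
    dtype=hidden_states.dtype,
    device=hidden_states.device,  # matches the activations instead of forcing CUDA
)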