Chengxu Zhuang
committed
Commit 272b3c6
1 Parent(s): ee546e1
minor fix for causal mask
modeling_flamingo.py  +14 -3
modeling_flamingo.py CHANGED
@@ -14,6 +14,12 @@ import transformers.models.opt.modeling_opt as modeling_opt
 from transformers.models.opt.modeling_opt\
     import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
 from transformers import ViTModel
+
+try:
+    from transformers.models.opt.modeling_opt import _prepare_4d_causal_attention_mask
+except:
+    _prepare_4d_causal_attention_mask = None
+
 from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
 from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
 from .configuration_flamingo import FlamingoConfig
@@ -232,9 +238,14 @@ class OPTDecoder(modeling_opt.OPTDecoder):
         attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, input_shape, inputs_embeds, past_key_values_length
-        )
+        if _prepare_4d_causal_attention_mask is None:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+        else:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )

         if self.project_in is not None:
             inputs_embeds = self.project_in(inputs_embeds)
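
Note (not part of the commit): newer transformers releases dropped the OPTDecoder._prepare_decoder_attention_mask method and instead build the causal mask with the module-level helper _prepare_4d_causal_attention_mask, which is why the import above is wrapped in try/except. The sketch below is a minimal illustration of that fallback pattern and of what the helper returns; it assumes a transformers version (roughly 4.35 or later) that exposes the helper under transformers.modeling_attn_mask_utils, a different import path from the modeling_opt re-export used in the commit.

    # Minimal sketch of the version-guarded import and of the 4D causal mask it builds.
    # The modeling_attn_mask_utils path is an assumption about newer transformers
    # releases; on older releases the except branch leaves the name as None.
    import torch

    try:
        from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
    except ImportError:
        _prepare_4d_causal_attention_mask = None

    if _prepare_4d_causal_attention_mask is not None:
        batch, seq_len, hidden = 2, 5, 8
        inputs_embeds = torch.zeros(batch, seq_len, hidden)
        attention_mask = torch.ones(batch, seq_len)  # 2D padding mask, 1 = attend

        # Expands the 2D padding mask into a 4D additive causal mask of shape
        # (batch, 1, seq_len, seq_len): 0 where attention is allowed, a large
        # negative value for future (above-diagonal) positions.
        causal_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch, seq_len), inputs_embeds, past_key_values_length=0
        )
        print(causal_mask.shape)  # torch.Size([2, 1, 5, 5])

When the import fails on an older transformers release, the name stays None and the decoder falls back to its own _prepare_decoder_attention_mask, exactly as in the second hunk above.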