Chengxu Zhuang committed on
Commit 272b3c6
Parent: ee546e1

minor fix for causal mask

Files changed (1):
  modeling_flamingo.py  +14 -3
modeling_flamingo.py CHANGED
@@ -14,6 +14,12 @@ import transformers.models.opt.modeling_opt as modeling_opt
 from transformers.models.opt.modeling_opt\
     import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
 from transformers import ViTModel
+
+try:
+    from transformers.models.opt.modeling_opt import _prepare_4d_causal_attention_mask
+except:
+    _prepare_4d_causal_attention_mask = None
+
 from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
 from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
 from .configuration_flamingo import FlamingoConfig
@@ -232,9 +238,14 @@ class OPTDecoder(modeling_opt.OPTDecoder):
         attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
 
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, input_shape, inputs_embeds, past_key_values_length
-        )
+        if _prepare_4d_causal_attention_mask is None:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+        else:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
 
         if self.project_in is not None:
             inputs_embeds = self.project_in(inputs_embeds)
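
The change makes the mask construction work across transformers versions: newer releases drop the decoder's private _prepare_decoder_attention_mask method in favor of the module-level _prepare_4d_causal_attention_mask helper, so the import is attempted up front and the call site uses whichever is available. The snippet below is a minimal standalone sketch of that pattern, not part of the commit; the tensor sizes are invented for illustration, and it catches ImportError explicitly where the commit uses a bare except.

import torch

try:
    # Available in transformers releases that expose the helper through
    # modeling_opt (the same import path the commit uses).
    from transformers.models.opt.modeling_opt import _prepare_4d_causal_attention_mask
except ImportError:
    _prepare_4d_causal_attention_mask = None

# Hypothetical sizes, just to exercise the helper.
batch_size, seq_len, hidden_size, past_len = 2, 5, 16, 0
inputs_embeds = torch.randn(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

if _prepare_4d_causal_attention_mask is not None:
    # Expands the 2D padding mask into an additive 4D causal mask of shape
    # (batch_size, 1, seq_len, past_len + seq_len).
    causal_mask = _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_len), inputs_embeds, past_len
    )
    print(causal_mask.shape)  # torch.Size([2, 1, 5, 5])
else:
    # Older transformers: inside the model, OPTDecoder._prepare_decoder_attention_mask
    # builds the equivalent mask, as in the fallback branch of the diff above.
    pass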