robinzixuan committed
Commit b0eaf6f · verified · Parent: 7f1445a

Update modeling_opt.py

Files changed (1)
  1. modeling_opt.py +4 -3
modeling_opt.py CHANGED
@@ -38,6 +38,7 @@ from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
@@ -725,10 +726,10 @@ class OPTDecoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size
 
-        self.self_attn = OPT_ATTENTION_CLASSES[config.attn_implementation](
+        self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](
             config=config, is_decoder=True)
         print(self.self_attn)
-        print(config.attn_implementation)
+        print(config._attn_implementation)
         self.do_layer_norm_before = config.do_layer_norm_before
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
@@ -970,7 +971,7 @@ class OPTDecoder(OPTPreTrainedModel):
 
         self.layers = nn.ModuleList(
             [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self._use_flash_attention_2 = config.attn_implementation == "flash_attention_2"
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
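
For context on what the rename fixes: `OPT_ATTENTION_CLASSES` in `modeling_opt.py` is a plain dict that maps the implementation string transformers records on the config (the private `_attn_implementation` attribute) to an attention class. Below is a minimal, self-contained sketch of that dispatch pattern; `EagerAttention`, `FlashAttention2`, and `DummyConfig` are placeholder names for illustration, not the actual transformers classes.

import torch.nn as nn

class EagerAttention(nn.Module):
    """Stand-in for the default (eager) OPT attention class."""
    def __init__(self, config=None, is_decoder=False):
        super().__init__()
        self.config = config
        self.is_decoder = is_decoder

class FlashAttention2(EagerAttention):
    """Stand-in for the flash-attention-2 variant."""
    pass

# Keyed by the string transformers stores on config._attn_implementation.
OPT_ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
}

class DummyConfig:
    # transformers sets this private attribute when the model is loaded,
    # e.g. AutoModel.from_pretrained(..., attn_implementation="flash_attention_2").
    _attn_implementation = "eager"

config = DummyConfig()
# The pattern the commit fixes: index with the private `_attn_implementation`
# attribute rather than a public `attn_implementation` one.
self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](
    config=config, is_decoder=True)
print(type(self_attn).__name__)  # -> EagerAttention

The same private attribute drives the `_use_flash_attention_2` flag set in the third hunk, which is why all three call sites are renamed together.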