Update modeling_opt.py
Browse files- modeling_opt.py +4 -3
modeling_opt.py
CHANGED
@@ -38,6 +38,7 @@ from transformers.utils import (
|
|
38 |
add_code_sample_docstrings,
|
39 |
add_start_docstrings,
|
40 |
add_start_docstrings_to_model_forward,
|
|
|
41 |
is_flash_attn_2_available,
|
42 |
is_flash_attn_greater_or_equal_2_10,
|
43 |
logging,
|
@@ -725,10 +726,10 @@ class OPTDecoderLayer(nn.Module):
|
|
725 |
super().__init__()
|
726 |
self.embed_dim = config.hidden_size
|
727 |
|
728 |
-
self.self_attn = OPT_ATTENTION_CLASSES[config.
|
729 |
config=config, is_decoder=True)
|
730 |
print(self.self_attn)
|
731 |
-
print(config.
|
732 |
self.do_layer_norm_before = config.do_layer_norm_before
|
733 |
self.dropout = config.dropout
|
734 |
self.activation_fn = ACT2FN[config.activation_function]
|
@@ -970,7 +971,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
|
970 |
|
971 |
self.layers = nn.ModuleList(
|
972 |
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
973 |
-
self._use_flash_attention_2 = config.
|
974 |
|
975 |
self.gradient_checkpointing = False
|
976 |
# Initialize weights and apply final processing
|
|
|
38 |
add_code_sample_docstrings,
|
39 |
add_start_docstrings,
|
40 |
add_start_docstrings_to_model_forward,
|
41 |
+
|
42 |
is_flash_attn_2_available,
|
43 |
is_flash_attn_greater_or_equal_2_10,
|
44 |
logging,
|
|
|
726 |
super().__init__()
|
727 |
self.embed_dim = config.hidden_size
|
728 |
|
729 |
+
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](
|
730 |
config=config, is_decoder=True)
|
731 |
print(self.self_attn)
|
732 |
+
print(config._attn_implementation)
|
733 |
self.do_layer_norm_before = config.do_layer_norm_before
|
734 |
self.dropout = config.dropout
|
735 |
self.activation_fn = ACT2FN[config.activation_function]
|
|
|
971 |
|
972 |
self.layers = nn.ModuleList(
|
973 |
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
974 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
975 |
|
976 |
self.gradient_checkpointing = False
|
977 |
# Initialize weights and apply final processing
|