m10an committed · verified
Commit 550be4b · Parent: 22c54fa

Update configuration_bert.py

Files changed (1):
  1. configuration_bert.py  +5 -0
configuration_bert.py CHANGED
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML Examples authors
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Optional
 from transformers import BertConfig as TransformersBertConfig
 
 
@@ -10,6 +11,7 @@ class BertConfig(TransformersBertConfig):
         self,
         alibi_starting_size: int = 512,
         attention_probs_dropout_prob: float = 0.0,
+        flash_attn_type: Optional[str] = None,
         **kwargs,
     ):
         """Configuration class for MosaicBert.
@@ -20,7 +22,10 @@ class BertConfig(TransformersBertConfig):
                 Defaults to 512.
             attention_probs_dropout_prob (float): By default, turn off attention dropout in Mosaic BERT
                 (otherwise, Flash Attention will be off by default). Defaults to 0.0.
+            flash_attn_type (str): if 'triton' is passed will use ./flash_attn_triton.py.
+                Defaults to None (disabled).
         """
         super().__init__(
             attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
         self.alibi_starting_size = alibi_starting_size
+        self.flash_attn_type = flash_attn_type
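
For context, a minimal sketch of how the new option might be exercised, assuming the file is importable as configuration_bert; the constructor arguments are exactly those shown in the diff, but the usage itself is illustrative rather than taken from this commit:

    from configuration_bert import BertConfig  # the class patched in this commit

    # Opt into the Triton flash-attention path added here.
    # flash_attn_type=None (the default) leaves the new behavior disabled.
    config = BertConfig(
        alibi_starting_size=512,
        attention_probs_dropout_prob=0.0,  # dropout stays off so Flash Attention remains usable
        flash_attn_type="triton",          # per the docstring, routed to ./flash_attn_triton.py
    )

    print(config.flash_attn_type)  # -> 'triton'

Defaulting the new field to None presumably lets configs saved before this change load unchanged, since a missing key simply falls back to the default.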