Tom Aarsen commited on
Commit
0d894e3
·
1 Parent(s): 66ea5d6

Allow loading via AutoModelForSequenceClassification

Browse files
Files changed (1) hide show
  1. bert_layers.py +9 -0
bert_layers.py CHANGED
@@ -29,6 +29,7 @@ from transformers.modeling_outputs import (MaskedLMOutput,
29
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
30
 
31
  from .blockdiag_linear import BlockdiagLinear
 
32
  from .monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing
33
 
34
  logger = logging.getLogger(__name__)
@@ -475,6 +476,8 @@ class BertModel(BertPreTrainedModel):
475
  ```
476
  """
477
 
 
 
478
  def __init__(self, config, add_pooling_layer=True):
479
  super(BertModel, self).__init__(config)
480
  self.embeddings = BertEmbeddings(config)
@@ -602,6 +605,8 @@ class BertOnlyNSPHead(nn.Module):
602
  #######################
603
  class BertForMaskedLM(BertPreTrainedModel):
604
 
 
 
605
  def __init__(self, config):
606
  super().__init__(config)
607
 
@@ -748,6 +753,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
748
  e.g., GLUE tasks.
749
  """
750
 
 
 
751
  def __init__(self, config):
752
  super().__init__(config)
753
  self.num_labels = config.num_labels
@@ -873,6 +880,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
873
 
874
  class BertForTextEncoding(BertPreTrainedModel):
875
 
 
 
876
  def __init__(self, config):
877
  super().__init__(config)
878
 
 
29
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
30
 
31
  from .blockdiag_linear import BlockdiagLinear
32
+ from .configuration_bert import BertConfig
33
  from .monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing
34
 
35
  logger = logging.getLogger(__name__)
 
476
  ```
477
  """
478
 
479
+ config_class = BertConfig
480
+
481
  def __init__(self, config, add_pooling_layer=True):
482
  super(BertModel, self).__init__(config)
483
  self.embeddings = BertEmbeddings(config)
 
605
  #######################
606
  class BertForMaskedLM(BertPreTrainedModel):
607
 
608
+ config_class = BertConfig
609
+
610
  def __init__(self, config):
611
  super().__init__(config)
612
 
 
753
  e.g., GLUE tasks.
754
  """
755
 
756
+ config_class = BertConfig
757
+
758
  def __init__(self, config):
759
  super().__init__(config)
760
  self.num_labels = config.num_labels
 
880
 
881
  class BertForTextEncoding(BertPreTrainedModel):
882
 
883
+ config_class = BertConfig
884
+
885
  def __init__(self, config):
886
  super().__init__(config)
887