yangwang825 committed on
Commit
88ce039
·
1 Parent(s): 48a4d29

Upload BertForSequenceClassification

Browse files
Files changed (3) hide show
  1. config.json +6 -1
  2. modeling_bert.py +6 -11
  3. pytorch_model.bin +1 -1
config.json CHANGED
@@ -1,7 +1,11 @@
1
  {
 
 
 
2
  "attention_probs_dropout_prob": 0.1,
3
  "auto_map": {
4
- "AutoConfig": "configuration_bert.BertConfig"
 
5
  },
6
  "classifier_dropout": null,
7
  "hidden_act": "gelu",
@@ -16,6 +20,7 @@
16
  "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
 
19
  "transformers_version": "4.33.3",
20
  "type_vocab_size": 2,
21
  "use_cache": true,
 
1
  {
2
+ "architectures": [
3
+ "BertForSequenceClassification"
4
+ ],
5
  "attention_probs_dropout_prob": 0.1,
6
  "auto_map": {
7
+ "AutoConfig": "configuration_bert.BertConfig",
8
+ "AutoModelForSequenceClassification": "modeling_bert.BertForSequenceClassification"
9
  },
10
  "classifier_dropout": null,
11
  "hidden_act": "gelu",
 
20
  "num_hidden_layers": 12,
21
  "pad_token_id": 0,
22
  "position_embedding_type": "absolute",
23
+ "torch_dtype": "float32",
24
  "transformers_version": "4.33.3",
25
  "type_vocab_size": 2,
26
  "use_cache": true,
modeling_bert.py CHANGED
@@ -19,16 +19,12 @@ from transformers.modeling_outputs import (
19
  SequenceClassifierOutput
20
  )
21
 
22
- from .configuration_bert import BertClsConfig
23
 
24
 
25
  class BertPreTrainedModel(PreTrainedModel):
26
- """
27
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
28
- models.
29
- """
30
 
31
- config_class = BertClsConfig
32
  load_tf_weights = load_tf_weights_in_bert
33
  base_model_prefix = "bert"
34
  supports_gradient_checkpointing = True
@@ -50,7 +46,7 @@ class BertPreTrainedModel(PreTrainedModel):
50
  module.weight.data.fill_(1.0)
51
 
52
 
53
- class BertClsPooler(nn.Module):
54
 
55
  def __init__(self, config):
56
  super().__init__()
@@ -68,7 +64,7 @@ class BertClsPooler(nn.Module):
68
 
69
  class BertModel(BertPreTrainedModel):
70
 
71
- config_class = BertClsConfig
72
 
73
  def __init__(self, config, add_pooling_layer=True):
74
  super().__init__(config)
@@ -77,7 +73,7 @@ class BertModel(BertPreTrainedModel):
77
  self.embeddings = BertEmbeddings(config)
78
  self.encoder = BertEncoder(config)
79
 
80
- self.pooler = BertClsPooler(config) if add_pooling_layer else None
81
 
82
  # Initialize weights and apply final processing
83
  self.post_init()
@@ -201,7 +197,7 @@ class BertModel(BertPreTrainedModel):
201
 
202
  class BertForSequenceClassification(BertPreTrainedModel):
203
 
204
- config_class = BertClsConfig
205
 
206
  def __init__(self, config):
207
  super().__init__(config)
@@ -290,4 +286,3 @@ class BertForSequenceClassification(BertPreTrainedModel):
290
  hidden_states=outputs.hidden_states,
291
  attentions=outputs.attentions,
292
  )
293
-
 
19
  SequenceClassifierOutput
20
  )
21
 
22
+ from .configuration_bert import BertConfig
23
 
24
 
25
  class BertPreTrainedModel(PreTrainedModel):
 
 
 
 
26
 
27
+ config_class = BertConfig
28
  load_tf_weights = load_tf_weights_in_bert
29
  base_model_prefix = "bert"
30
  supports_gradient_checkpointing = True
 
46
  module.weight.data.fill_(1.0)
47
 
48
 
49
+ class BertPooler(nn.Module):
50
 
51
  def __init__(self, config):
52
  super().__init__()
 
64
 
65
  class BertModel(BertPreTrainedModel):
66
 
67
+ config_class = BertConfig
68
 
69
  def __init__(self, config, add_pooling_layer=True):
70
  super().__init__(config)
 
73
  self.embeddings = BertEmbeddings(config)
74
  self.encoder = BertEncoder(config)
75
 
76
+ self.pooler = BertPooler(config) if add_pooling_layer else None
77
 
78
  # Initialize weights and apply final processing
79
  self.post_init()
 
197
 
198
  class BertForSequenceClassification(BertPreTrainedModel):
199
 
200
+ config_class = BertConfig
201
 
202
  def __init__(self, config):
203
  super().__init__(config)
 
286
  hidden_states=outputs.hidden_states,
287
  attentions=outputs.attentions,
288
  )
 
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6302c5d80ac329f1276bc9de48a4d43959ed0a4e84b7b97ef722792fe825652f
3
  size 438000689
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:601a8071a8e164093f8cbf0ed22b304427f5feff9f93aee5963fc4081a735fe5
3
  size 438000689