joaogante HF staff committed on
Commit
29ac323
·
1 Parent(s): fbcbb1a

ForCausalLM

Browse files
Files changed (3) hide show
  1. README.md +2 -0
  2. config.json +2 -2
  3. modeling.py +3 -3
README.md CHANGED
@@ -1,3 +1,5 @@
1
  ---
2
  license: apache-2.0
3
  ---
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
4
+
5
+ This is the same as `test_dynamic_model`, but with a `generate`-compatible class
config.json CHANGED
@@ -6,7 +6,7 @@
6
  "attention_probs_dropout_prob": 0.1,
7
  "auto_map": {
8
  "AutoConfig": "configuration.NewModelConfig",
9
- "AutoModel": "modeling.NewModel"
10
  },
11
  "classifier_dropout": null,
12
  "hidden_act": "gelu",
@@ -23,7 +23,7 @@
23
  "pad_token_id": 0,
24
  "position_embedding_type": "absolute",
25
  "torch_dtype": "float32",
26
- "transformers_version": "4.16.0.dev0",
27
  "type_vocab_size": 2,
28
  "use_cache": true,
29
  "vocab_size": 30522
 
6
  "attention_probs_dropout_prob": 0.1,
7
  "auto_map": {
8
  "AutoConfig": "configuration.NewModelConfig",
9
+ "AutoModelForCausalLM": "modeling.NewModelForCausalLM"
10
  },
11
  "classifier_dropout": null,
12
  "hidden_act": "gelu",
 
23
  "pad_token_id": 0,
24
  "position_embedding_type": "absolute",
25
  "torch_dtype": "float32",
26
+ "transformers_version": "4.45.0.dev0",
27
  "type_vocab_size": 2,
28
  "use_cache": true,
29
  "vocab_size": 30522
modeling.py CHANGED
@@ -1,11 +1,11 @@
1
  import torch
2
- from transformers import BertModel
3
 
4
  from .configuration import NewModelConfig
5
 
6
- class NewModel(BertModel):
7
  config_class = NewModelConfig
8
 
9
  def __init__(self, config):
10
  super().__init__(config)
11
- self.last_layer = torch.nn.Linear(config.hidden_size, config.new_hidden_size)
 
1
  import torch
2
+ from transformers import BertLMHeadModel
3
 
4
  from .configuration import NewModelConfig
5
 
6
+ class NewModelForCausalLM(BertLMHeadModel):
7
  config_class = NewModelConfig
8
 
9
  def __init__(self, config):
10
  super().__init__(config)
11
+ self.last_layer = torch.nn.Linear(config.hidden_size, config.new_hidden_size)