hackyon committed
Commit b5c5b66 · verified · 1 parent: 7afe512

Upload EncT5ForSequenceClassification

Files changed (2):
  1. model.safetensors (+1 -1)
  2. modeling_enct5.py (+7 -2)
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1e9cc0194fa5bfc256b2e2d47affe664f166cdaf29430947220e1606223691cc
+oid sha256:e67a80a5bd78ab3885be58d157623db44c1ff78204817af457bd31b00e6b49aa
 size 476301088
modeling_enct5.py CHANGED
@@ -93,6 +93,7 @@ class EncT5ForSequenceClassification(EncT5PreTrainedModel):
 
         # Initiate decoder embedding from scratch and define the corresponding latent vector vocabulary size.
         self.decoder_embeddings = nn.Embedding(config.decoder_vocab_size, config.d_model)
+        self.transformer.get_decoder().set_input_embeddings(self.decoder_embeddings)
 
         # Initiate decoder projection head from scratch.
         if config.problem_type == "multi_label_classification":
@@ -107,14 +108,18 @@ class EncT5ForSequenceClassification(EncT5PreTrainedModel):
 
     def load_weights_from_pretrained_t5(self, model_path: str):
         pretrained_t5_model = T5Model.from_pretrained(model_path)
-        self.transformer.load_state_dict(pretrained_t5_model.state_dict(), strict=False)
+
+        # Override the decoder embedding weights to make them the correct shape.
+        pretrained_state_dict = pretrained_t5_model.state_dict()
+        pretrained_state_dict["decoder.embed_tokens.weight"] = self.decoder_embeddings.state_dict()["weight"]
+
+        self.transformer.load_state_dict(pretrained_state_dict, strict=False)
 
     def prepare_for_fine_tuning(self):
         r"""
         Prepares the model for fine-tuning by re-initializing the necessary weights for fine-tuning. This step should be
         performed after loading the pre-trained T5 model but before fine-tuning.
         """
-        self.transformer.get_decoder().set_input_embeddings(self.decoder_embeddings)
         self.transformer.get_decoder().apply(self._init_weights)
         self._init_weights(self.classification_head)
 
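For context: PyTorch's load_state_dict raises on tensor shape mismatches even when strict=False, so the pretrained T5 decoder embedding (sized for the full T5 vocabulary) could not be loaded over the much smaller decoder_vocab_size table; copying the correctly shaped weight into the pretrained state dict first, as this commit does, avoids that error. Below is a minimal sketch of the resulting loading flow. The EncT5Config import, its module path, and the hyperparameter values are assumptions for illustration, not part of this commit.

# Minimal usage sketch. Assumed/hypothetical names: the configuration_enct5
# module, the EncT5Config class, and all hyperparameter values below.
from configuration_enct5 import EncT5Config
from modeling_enct5 import EncT5ForSequenceClassification

# Illustrative config: a tiny decoder vocabulary holding the latent query token.
config = EncT5Config(
    num_labels=2,
    decoder_vocab_size=1,
    problem_type="single_label_classification",
)
model = EncT5ForSequenceClassification(config)

# Copy encoder/decoder weights from a vanilla T5 checkpoint. The patched
# load_weights_from_pretrained_t5 swaps the small decoder embedding into the
# pretrained state dict before calling load_state_dict, so the shapes line up.
model.load_weights_from_pretrained_t5("t5-base")

# Re-initialize the decoder and classification head, then fine-tune as usual.
model.prepare_for_fine_tuning()

Moving the set_input_embeddings call into the constructor also means a model restored directly from a checkpoint gets the small decoder embedding wired in without having to call prepare_for_fine_tuning first.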