feat: removed task type embeddings
modeling_bert.py CHANGED (+1 -12)
@@ -152,7 +152,7 @@ def _init_weights(module, initializer_range=0.02):
         nn.init.normal_(module.weight, std=initializer_range)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
-    elif isinstance(module, nn.Embedding)
+    elif isinstance(module, nn.Embedding):
         nn.init.normal_(module.weight, std=initializer_range)
         if module.padding_idx is not None:
             nn.init.zeros_(module.weight[module.padding_idx])
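For context on how this initializer is wired in, here is a minimal, hedged sketch: nn.Module.apply walks every submodule and calls _init_weights on each one, so once the task-type table and its skip_init bookkeeping are gone (see the hunks below), every nn.Embedding is initialized from the same normal distribution. The helper mirrors the shape of the diff's _init_weights; the toy module and its sizes are made up for illustration.

from functools import partial

import torch.nn as nn

def _init_weights(module, initializer_range=0.02):
    # Same shape as the helper in the diff: normal init for Linear/Embedding
    # weights, zeros for biases and for the padding row of an embedding table.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.padding_idx is not None:
            nn.init.zeros_(module.weight[module.padding_idx])

# Toy container, not the repository's model; sizes (10, 4, 2) are arbitrary.
toy = nn.Sequential(nn.Embedding(10, 4, padding_idx=0), nn.Linear(4, 2))
# apply() recurses through submodules, mirroring
# self.apply(partial(_init_weights, initializer_range=config.initializer_range)).
toy.apply(partial(_init_weights, initializer_range=0.02))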
@@ -351,7 +351,6 @@ class BertModel(BertPreTrainedModel):
         self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
-        self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
 
         self.emb_pooler = config.emb_pooler
         self._name_or_path = config._name_or_path
@@ -362,13 +361,6 @@ class BertModel(BertPreTrainedModel):
         else:
             self.tokenizer = None
 
-        # We now initialize the task embeddings to 0; We do not use task types during
-        # pretraining. When we start using task types during embedding training,
-        # we want the model to behave exactly as in pretraining (i.e. task types
-        # have no effect).
-        nn.init.zeros_(self.task_type_embeddings.weight)
-        self.task_type_embeddings.skip_init = True
-        # The following code should skip the embeddings layer
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
 
     def forward(
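The comment block removed above documents the original intent: the task-type table was zero-initialized (and flagged with skip_init) so that adding it to the hidden states during pretraining changed nothing. A minimal, hedged sketch of that no-op property, using toy sizes rather than the model's config:

import torch
import torch.nn as nn

hidden_size, num_tasks = 4, 3                        # toy values, not config.hidden_size / config.num_tasks
task_type_embeddings = nn.Embedding(num_tasks, hidden_size)
nn.init.zeros_(task_type_embeddings.weight)          # the removed zero-init

hidden_states = torch.randn(2, 5, hidden_size)       # (batch, seq_len, hidden)
task_type_ids = torch.zeros(2, 5, dtype=torch.long)  # every token tagged with task 0

# An all-zero embedding row adds nothing, so the model behaves exactly as if
# task types did not exist; dropping the table is therefore safe.
out = hidden_states + task_type_embeddings(task_type_ids)
assert torch.equal(out, hidden_states)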
@@ -376,7 +368,6 @@ class BertModel(BertPreTrainedModel):
         input_ids,
         position_ids=None,
         token_type_ids=None,
-        task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
         return_dict=True,
@@ -389,8 +380,6 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
-        if task_type_ids is not None:
-            hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
 
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
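One practical consequence for downstream code: task_type_ids is no longer a parameter of forward, so callers that still pass it now fail at the keyword-argument level. A hedged stand-in sketch of the new parameter list (toy_forward is illustrative only, not the model's real forward):

import torch

def toy_forward(input_ids, position_ids=None, token_type_ids=None,
                attention_mask=None, masked_tokens_mask=None, return_dict=True):
    # Stand-in with the post-commit parameter list; it just echoes its input.
    return input_ids

input_ids = torch.randint(0, 30000, (1, 8))                         # toy token ids
toy_forward(input_ids, attention_mask=torch.ones_like(input_ids))   # still fine
try:
    toy_forward(input_ids, task_type_ids=torch.zeros(1, 8, dtype=torch.long))
except TypeError as err:
    print("task_type_ids is no longer accepted:", err)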