from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.models.luke.modeling_luke import (
    EntityPredictionHead,
    LukeLMHead,
    LukeModel,
)
from transformers.utils import ModelOutput

from .configuration_ubke import UbkeConfig


@dataclass
class UbkeMaskedLMOutput(ModelOutput):
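    """
    Output type of [`UbkeForMaskedLM`].

    Bundles the combined ``loss`` with the individual pretraining losses
    (masked language modeling ``mlm_loss``, masked entity prediction
    ``mep_loss``, topic entity prediction ``tep_loss`` and topic category
    prediction ``tcp_loss``), the corresponding logits, and the hidden states
    and attentions returned by the underlying LUKE encoder.
    """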
    loss: Optional[torch.FloatTensor] = None
    mlm_loss: Optional[torch.FloatTensor] = None
    mep_loss: Optional[torch.FloatTensor] = None
    tep_loss: Optional[torch.FloatTensor] = None
    tcp_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    entity_logits: Optional[torch.FloatTensor] = None
    topic_entity_logits: torch.FloatTensor = None
    topic_category_logits: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_last_hidden_state: torch.FloatTensor = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class UbkePreTrainedModel(PreTrainedModel):
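    """
    An abstract class that handles weights initialization and provides the
    usual interface for downloading and loading pretrained checkpoints.
    """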
    config_class = UbkeConfig
    base_model_prefix = "luke"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]

    def _init_weights(self, module: nn.Module):
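        # LUKE-style initialization: Linear/Embedding weights are drawn from
        # N(0, initializer_range); embeddings of dimension 1 (bias-like
        # parameters) and padding rows are zero-initialized; LayerNorm is
        # reset to weight 1, bias 0.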
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            if module.embedding_dim == 1:
                module.weight.data.zero_()
            else:
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class UbkeForMaskedLM(UbkePreTrainedModel):
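    """
    UBKE model with a masked language modeling head and an entity prediction
    head on top of the LUKE encoder. Alongside masked language modeling (MLM),
    it computes masked entity prediction (MEP), topic entity prediction (TEP)
    and topic category prediction (TCP) losses when the corresponding labels
    are given.

    Minimal usage sketch (illustrative only: it assumes ``UbkeConfig`` and
    ``torch`` are in scope, that the config defaults are compatible with the
    toy ids below, and that inputs follow the LUKE convention of
    ``entity_position_ids`` padded with ``-1``):

    ```python
    config = UbkeConfig()
    model = UbkeForMaskedLM(config)
    outputs = model(
        input_ids=torch.tensor([[0, 10, 20, 2]]),
        attention_mask=torch.tensor([[1, 1, 1, 1]]),
        entity_ids=torch.tensor([[2]]),
        entity_attention_mask=torch.tensor([[1]]),
        entity_position_ids=torch.tensor([[[1, 2, -1, -1]]]),
    )
    word_logits = outputs.logits            # (1, 4, vocab_size)
    entity_logits = outputs.entity_logits   # (1, 1, entity_vocab_size)
    topic_entity_logits = outputs.topic_entity_logits
    ```
    """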
    _tied_weights_keys = [
        "lm_head.decoder.weight",
        "lm_head.decoder.bias",
        "entity_predictions.decoder.weight",
    ]

    def __init__(self, config: UbkeConfig):
        super().__init__(config)

        self.luke = LukeModel(config)
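
        # Optionally rebuild the entity embedding table with ``max_norm=1.0`` so
        # that entity embedding vectors are renormalized to at most unit L2 norm
        # whenever they are looked up.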
        if self.config.normalize_entity_embeddings:
            self.luke.entity_embeddings.entity_embeddings = nn.Embedding(
                config.entity_vocab_size,
                config.entity_emb_size,
                padding_idx=0,
                max_norm=1.0,
            )

        self.lm_head = LukeLMHead(config)
        self.entity_predictions = EntityPredictionHead(config)

        self.loss_fn = nn.CrossEntropyLoss()

        self.post_init()

    def tie_weights(self):
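        # Beyond the word embedding/LM head tying done by ``super().tie_weights()``,
        # tie the entity prediction decoder to the entity embedding table.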
        super().tie_weights()
        self._tie_or_clone_weights(
            self.entity_predictions.decoder,
            self.luke.entity_embeddings.entity_embeddings,
        )

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        entity_ids: Optional[torch.LongTensor] = None,
        entity_attention_mask: Optional[torch.LongTensor] = None,
        entity_token_type_ids: Optional[torch.LongTensor] = None,
        entity_position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        entity_labels: Optional[torch.LongTensor] = None,
        topic_entity_labels: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, UbkeMaskedLMOutput]:
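        r"""
        Run the LUKE encoder on the word and entity inputs, then compute masked
        language modeling logits for every word position, entity prediction
        logits for every entity position, and topic entity/category logits from
        the first word token. When ``labels``, ``entity_labels`` or
        ``topic_entity_labels`` are provided, the corresponding losses are
        computed and summed into ``loss``.
        """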
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.luke(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        loss = None
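
        # Masked language modeling (MLM): predict masked word tokens from the
        # word-level hidden states.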
        mlm_loss = None
        logits = self.lm_head(outputs.last_hidden_state)
        if labels is not None:
            labels = labels.to(logits.device)
            mlm_loss = self.loss_fn(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )
            if loss is None:
                loss = mlm_loss
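
        # Masked entity prediction (MEP): predict masked entities from the
        # entity-level hidden states; logits are scaled by the entity
        # temperature before the cross-entropy loss.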
        mep_loss = None
        entity_logits = None
        if outputs.entity_last_hidden_state is not None:
            entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
            if entity_labels is not None:
                mep_loss = self.loss_fn(
                    entity_logits.view(-1, self.config.entity_vocab_size)
                    / self.config.entity_temperature,
                    entity_labels.view(-1),
                )
                if loss is None:
                    loss = mep_loss
                else:
                    loss = loss + mep_loss
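
        # Topic entity prediction uses the first word token's hidden state; the
        # last ``num_category_entities`` entries of the entity vocabulary are
        # treated as category entities and split off into separate logits.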
        topic_entity_logits = self.entity_predictions(outputs.last_hidden_state[:, 0])
        topic_category_logits = None
        if self.config.num_category_entities > 0:
            topic_category_logits = topic_entity_logits[
                :, -self.config.num_category_entities :
            ]
            topic_entity_logits = topic_entity_logits[
                :, : -self.config.num_category_entities
            ]

        topic_category_labels = None
        if topic_entity_labels is not None and self.config.num_category_entities > 0:
            topic_category_labels = topic_entity_labels[
                :, -self.config.num_category_entities :
            ]
            topic_entity_labels = topic_entity_labels[
                :, : -self.config.num_category_entities
            ]
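
        # Topic entity prediction (TEP): soft cross-entropy between the
        # temperature-scaled logits and the normalized multi-hot topic entity
        # labels, restricted to examples with at least one topic entity label.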
        tep_loss = None
        if topic_entity_labels is not None:
            num_topic_entity_labels = topic_entity_labels.sum(dim=1)
            if (num_topic_entity_labels > 0).any():
                topic_entity_labels = topic_entity_labels.to(
                    topic_entity_logits.dtype
                ) / num_topic_entity_labels.unsqueeze(-1)
                tep_loss = self.loss_fn(
                    topic_entity_logits[num_topic_entity_labels > 0]
                    / self.config.entity_temperature,
                    topic_entity_labels[num_topic_entity_labels > 0],
                )
                if loss is None:
                    loss = tep_loss
                else:
                    loss = loss + tep_loss
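
        # Topic category prediction (TCP): same soft cross-entropy objective,
        # applied to the category entity logits.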
        tcp_loss = None
        if topic_category_labels is not None:
            num_topic_category_labels = topic_category_labels.sum(dim=1)
            if (num_topic_category_labels > 0).any():
                topic_category_labels = topic_category_labels.to(
                    topic_category_logits.dtype
                ) / num_topic_category_labels.unsqueeze(-1)
                tcp_loss = self.loss_fn(
                    topic_category_logits[num_topic_category_labels > 0]
                    / self.config.entity_temperature,
                    topic_category_labels[num_topic_category_labels > 0],
                )
                if loss is None:
                    loss = tcp_loss
                else:
                    loss = loss + tcp_loss
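
        # Tuple output keeps only the non-None tensors; the losses are returned
        # only in the ``UbkeMaskedLMOutput`` when ``return_dict=True``.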
        if not return_dict:
            return tuple(
                v
                for v in [
                    logits,
                    entity_logits,
                    topic_entity_logits,
                    topic_category_logits,
                    outputs.last_hidden_state,
                    outputs.entity_last_hidden_state,
                    outputs.hidden_states,
                    outputs.entity_hidden_states,
                    outputs.attentions,
                ]
                if v is not None
            )

        return UbkeMaskedLMOutput(
            loss=loss,
            mlm_loss=mlm_loss,
            mep_loss=mep_loss,
            tep_loss=tep_loss,
            tcp_loss=tcp_loss,
            logits=logits,
            entity_logits=entity_logits,
            topic_entity_logits=topic_entity_logits,
            topic_category_logits=topic_category_logits,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            entity_last_hidden_state=outputs.entity_last_hidden_state,
            entity_hidden_states=outputs.entity_hidden_states,
            attentions=outputs.attentions,
        )