# UBKE-LUKE / modeling_ubke.py
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.models.luke.modeling_luke import (
EntityPredictionHead,
LukeLMHead,
LukeModel,
)
from transformers.utils import ModelOutput
from .configuration_ubke import UbkeConfig


@dataclass
class UbkeMaskedLMOutput(ModelOutput):
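    """Output type for `UbkeForMaskedLM`.

    In addition to the summed `loss`, the individual pretraining losses are
    exposed: masked language modeling (`mlm_loss`), entity prediction
    (`mep_loss`), topic entity prediction (`tep_loss`) and topic category
    prediction (`tcp_loss`), together with their logits and the word/entity
    hidden states and attentions of the backbone.
    """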
loss: Optional[torch.FloatTensor] = None
mlm_loss: Optional[torch.FloatTensor] = None
mep_loss: Optional[torch.FloatTensor] = None
tep_loss: Optional[torch.FloatTensor] = None
tcp_loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
entity_logits: Optional[torch.FloatTensor] = None
topic_entity_logits: torch.FloatTensor = None
topic_category_logits: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
entity_last_hidden_state: torch.FloatTensor = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class UbkePreTrainedModel(PreTrainedModel):
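    """Base class that ties UBKE models to `UbkeConfig` and provides the
    LUKE-style weight initialization."""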
config_class = UbkeConfig
base_model_prefix = "luke"
supports_gradient_checkpointing = True
_no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]

    def _init_weights(self, module: nn.Module):
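        """Initialize module weights: normal init for linear and embedding
        layers, zeros for 1-dim embeddings used as bias parameters, and the
        usual LayerNorm reset."""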
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
if module.embedding_dim == 1: # embedding for bias parameters
module.weight.data.zero_()
else:
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)


class UbkeForMaskedLM(UbkePreTrainedModel):
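    """UBKE model with a language modeling head and an entity prediction head
    on top, trained with masked language modeling (MLM), entity prediction
    (MEP), topic entity prediction (TEP) and topic category prediction (TCP)
    objectives.

    The word LM head decoder and the entity prediction decoder are tied to the
    corresponding input embeddings (see `_tied_weights_keys` and `tie_weights`).
    """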
_tied_weights_keys = [
"lm_head.decoder.weight",
"lm_head.decoder.bias",
"entity_predictions.decoder.weight",
]

    def __init__(self, config: UbkeConfig):
super().__init__(config)
self.luke = LukeModel(config)
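        # Optionally rebuild the entity embedding table with max_norm=1.0 so
        # that looked-up entity embeddings are renormalized to unit norm.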
if self.config.normalize_entity_embeddings:
self.luke.entity_embeddings.entity_embeddings = nn.Embedding(
config.entity_vocab_size,
config.entity_emb_size,
padding_idx=0,
max_norm=1.0,
)
self.lm_head = LukeLMHead(config)
self.entity_predictions = EntityPredictionHead(config)
self.loss_fn = nn.CrossEntropyLoss()
# Initialize weights and apply final processing
self.post_init()

    def tie_weights(self):
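        # The parent class ties the word LM head; additionally tie the entity
        # prediction decoder to the entity embedding table.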
super().tie_weights()
self._tie_or_clone_weights(
self.entity_predictions.decoder,
self.luke.entity_embeddings.entity_embeddings,
)

    def get_output_embeddings(self) -> nn.Module:
return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings: nn.Module):
self.lm_head.decoder = new_embeddings

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.LongTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
entity_labels: Optional[torch.LongTensor] = None,
topic_entity_labels: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, UbkeMaskedLMOutput]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
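        # Encode words and entities with the LUKE backbone.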
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
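        # Masked language modeling (MLM) loss over word tokens.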
loss = None
mlm_loss = None
logits = self.lm_head(outputs.last_hidden_state)
if labels is not None:
labels = labels.to(logits.device)
mlm_loss = self.loss_fn(
logits.view(-1, self.config.vocab_size), labels.view(-1)
)
if loss is None:
loss = mlm_loss
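        # Entity prediction (MEP) loss over the entity sequence, with
        # temperature-scaled logits.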
mep_loss = None
entity_logits = None
if outputs.entity_last_hidden_state is not None:
entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
if entity_labels is not None:
mep_loss = self.loss_fn(
entity_logits.view(-1, self.config.entity_vocab_size)
/ self.config.entity_temperature,
entity_labels.view(-1),
)
if loss is None:
loss = mep_loss
else:
loss = loss + mep_loss
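        # Topic entity prediction (TEP): score every entry of the entity
        # vocabulary from the representation of the first word token.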
topic_entity_logits = self.entity_predictions(outputs.last_hidden_state[:, 0])
topic_category_logits = None
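        # The last `num_category_entities` entries of the entity vocabulary are
        # category entities; split their logits (and labels) from the rest.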
if self.config.num_category_entities > 0:
topic_category_logits = topic_entity_logits[
:, -self.config.num_category_entities :
]
topic_entity_logits = topic_entity_logits[
:, : -self.config.num_category_entities
]
topic_category_labels = None
if topic_entity_labels is not None and self.config.num_category_entities > 0:
topic_category_labels = topic_entity_labels[
:, -self.config.num_category_entities :
]
topic_entity_labels = topic_entity_labels[
:, : -self.config.num_category_entities
]
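        # TEP loss: normalize the multi-hot topic entity labels into a
        # probability distribution and use cross entropy with soft targets,
        # skipping examples without any topic entity.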
tep_loss = None
if topic_entity_labels is not None:
num_topic_entity_labels = topic_entity_labels.sum(dim=1)
if (num_topic_entity_labels > 0).any():
topic_entity_labels = topic_entity_labels.to(
topic_entity_logits.dtype
) / num_topic_entity_labels.unsqueeze(-1)
tep_loss = self.loss_fn(
topic_entity_logits[num_topic_entity_labels > 0]
/ self.config.entity_temperature,
topic_entity_labels[num_topic_entity_labels > 0],
)
if loss is None:
loss = tep_loss
else:
loss = loss + tep_loss
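        # Topic category prediction (TCP) loss, computed like TEP but over the
        # category entities only.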
tcp_loss = None
if topic_category_labels is not None:
num_topic_category_labels = topic_category_labels.sum(dim=1)
if (num_topic_category_labels > 0).any():
topic_category_labels = topic_category_labels.to(
topic_category_logits.dtype
) / num_topic_category_labels.unsqueeze(-1)
tcp_loss = self.loss_fn(
topic_category_logits[num_topic_category_labels > 0]
/ self.config.entity_temperature,
topic_category_labels[num_topic_category_labels > 0],
)
if loss is None:
loss = tcp_loss
else:
loss = loss + tcp_loss
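        # Tuple output keeps only the non-None logits and hidden states; the
        # losses are available only through the dataclass output.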
if not return_dict:
return tuple(
v
for v in [
logits,
entity_logits,
topic_entity_logits,
topic_category_logits,
outputs.last_hidden_state,
outputs.entity_last_hidden_state,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
]
if v is not None
)
return UbkeMaskedLMOutput(
loss=loss,
mlm_loss=mlm_loss,
mep_loss=mep_loss,
tep_loss=tep_loss,
tcp_loss=tcp_loss,
logits=logits,
entity_logits=entity_logits,
topic_entity_logits=topic_entity_logits,
topic_category_logits=topic_category_logits,
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
entity_last_hidden_state=outputs.entity_last_hidden_state,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
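

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; the checkpoint path and label tensors
# below are placeholders, not part of this module):
#
#     model = UbkeForMaskedLM.from_pretrained("path/to/ubke-checkpoint")
#     outputs = model(
#         input_ids=input_ids,
#         attention_mask=attention_mask,
#         entity_ids=entity_ids,
#         entity_position_ids=entity_position_ids,
#         entity_attention_mask=entity_attention_mask,
#         labels=mlm_labels,
#         entity_labels=entity_labels,
#         topic_entity_labels=topic_entity_labels,
#     )
#     # outputs.loss sums the MLM, MEP, TEP and TCP losses that could be
#     # computed from the labels that were provided.
# ---------------------------------------------------------------------------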