# Commit e0ca997 by yifan: "Add model definition and pytorch model"
# Module authorship metadata.
# NOTE(review): "[email protected]" looks like a web email-obfuscation
# placeholder rather than a real address — restore the original if known.
__author__ = "Yifan Zhang ([email protected])"
__copyright__ = (
    "Copyright (C) 2021, "
    "Qatar Computing Research Institute, HBKU, Doha"
)
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn.functional import sigmoid
from transformers import BertPreTrainedModel, BertModel
from transformers.file_utils import ModelOutput
# Token-level tag inventory; a tag's position in this tuple is its label id.
#   0: "<PAD>" (the token loss in the model uses ignore_index=0)
#   1: "O"     (presumably "outside any technique span")
#   2..19: technique labels.
# Some tag names contain commas (e.g. "Name_Calling,Labeling"); each quoted
# string is a single tag.
TOKEN_TAGS = (
    "<PAD>",
    "O",
    "Name_Calling,Labeling",
    "Repetition",
    "Slogans",
    "Appeal_to_fear-prejudice",
    "Doubt",
    "Exaggeration,Minimisation",
    "Flag-Waving",
    "Loaded_Language",
    "Reductio_ad_hitlerum",
    "Bandwagon",
    "Causal_Oversimplification",
    "Obfuscation,Intentional_Vagueness,Confusion",
    "Appeal_to_Authority",
    "Black-and-White_Fallacy",
    "Thought-terminating_Cliches",
    "Red_Herring",
    "Straw_Men",
    "Whataboutism",
)

# Sentence-level labels; index 1 ("Prop") is the positive class.
SEQUENCE_TAGS = (
    "Non-prop",
    "Prop",
)
@dataclass
class TokenAndSequenceJointClassifierOutput(ModelOutput):
    """Outputs of ``BertForTokenAndSequenceJointClassification``.

    Every field defaults to ``None``; a field left as ``None`` is dropped
    from the tuple/indexed view that ``ModelOutput`` provides.
    """

    # Joint (token + sequence) loss; the model only sets it when labels are given.
    loss: Optional[torch.FloatTensor] = None
    # Gated per-token logits from the token-classification head.
    token_logits: Optional[torch.FloatTensor] = None
    # Per-example logits from the sequence-classification head.
    sequence_logits: Optional[torch.FloatTensor] = None
    # BERT hidden states, when requested via output_hidden_states.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # BERT attention maps, when requested via output_attentions.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
class BertForTokenAndSequenceJointClassification(BertPreTrainedModel):
    """BERT with a token-tagging head and a sentence-classification head.

    Token logits come from a linear layer over the (dropped-out) last hidden
    states. Sequence logits come from a linear layer over the (dropped-out)
    pooler output. The sequence logits also go through ``masking_gate`` and a
    sigmoid, giving one scalar per example that scales all of that example's
    token logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.token_tags = TOKEN_TAGS
        self.sequence_tags = SEQUENCE_TAGS
        self.num_token_labels = len(self.token_tags)  # 20
        self.num_sequence_labels = len(self.sequence_tags)  # 2
        # Weight of the token loss in the joint loss; the sequence loss gets 1 - alpha.
        self.alpha = 0.9

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # classifier[0]: token head; classifier[1]: sequence head.
        self.classifier = nn.ModuleList([
            nn.Linear(config.hidden_size, self.num_token_labels),
            nn.Linear(config.hidden_size, self.num_sequence_labels),
        ])
        # Maps the sequence logits to a single gate value per example.
        self.masking_gate = nn.Linear(self.num_sequence_labels, 1)
        self.init_weights()
        # NOTE(review): not used by forward() and created after init_weights();
        # presumably kept so existing checkpoints containing it load without
        # key mismatches — confirm before removing.
        self.merge_classifier_1 = nn.Linear(
            self.num_token_labels + self.num_sequence_labels, self.num_token_labels
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    ):
        """Run BERT and both heads; compute the joint loss when labels are given.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states:
                forwarded unchanged to ``BertModel``.
            labels: optional token-label ids indexing ``self.token_tags``;
                assumed shape (batch, seq_len), or (seq_len,) for a single
                example — TODO confirm against the data collator. Label 0
                (``"<PAD>"``) is ignored by the token loss.
            return_dict: falsy -> tuple output; ``None`` -> use
                ``config.use_return_dict``.

        Returns:
            ``TokenAndSequenceJointClassifierOutput`` holding the gated token
            logits (same shape as the token head's output) and the sequence
            logits. In tuple mode it returns
            ``([token_logits, sequence_logits], *bert_extras)``, prefixed with
            the loss when labels are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Always request a ModelOutput so .hidden_states/.attentions work
            # below whatever config.use_return_dict says; integer indexing and
            # slicing still behave like the tuple form.
            return_dict=True,
        )
        sequence_output = self.dropout(outputs[0])
        token_logits = self.classifier[0](sequence_output)
        pooler_output = self.dropout(outputs[1])
        sequence_logits = self.classifier[1](pooler_output)

        # One gate in (0, 1) per example; unsqueeze(1) lets it broadcast over
        # every token and every label (same result as an explicit repeat).
        gate = torch.sigmoid(self.masking_gate(sequence_logits))
        weighted_token_logits = token_logits * gate.unsqueeze(1)

        loss = None
        if labels is not None:
            loss = self._joint_loss(weighted_token_logits, sequence_logits, labels)

        if not return_dict:
            output = ([weighted_token_logits, sequence_logits],) + tuple(outputs[2:])
            return ((loss,) + output) if loss is not None else output
        return TokenAndSequenceJointClassifierOutput(
            loss=loss,
            token_logits=weighted_token_logits,
            sequence_logits=sequence_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _joint_loss(self, token_logits, sequence_logits, labels):
        """Return ``alpha * token_loss + (1 - alpha) * sequence_loss``.

        token_loss: cross-entropy over all tokens, ignoring label 0 (``"<PAD>"``).
        sequence_loss: BCE-with-logits against one-hot targets. An example is
        "Prop" iff at least one of its tokens has a tag after ``"O"`` in
        ``self.token_tags``, i.e. a technique tag.
        """
        # Flatten into locals so the caller's (unflattened) logits are returned unchanged.
        flat_token_logits = token_logits.reshape(-1, token_logits.size(-1))
        token_criterion = nn.CrossEntropyLoss(ignore_index=0)
        token_loss = token_criterion(flat_token_logits, labels.reshape(-1))

        # Group labels per example (a 1-D labels tensor becomes a batch of one).
        per_example_labels = labels.reshape(sequence_logits.size(0), -1)
        outside_id = self.token_tags.index("O")
        sequence_labels = (per_example_labels > outside_id).any(dim=1).long()
        # BCEWithLogitsLoss needs a float target with the same shape as the logits.
        sequence_targets = nn.functional.one_hot(
            sequence_labels, num_classes=sequence_logits.size(-1)
        ).to(sequence_logits.dtype)
        # pos_weight value carried over from the original code (its derivation
        # isn't shown here); created on the logits' device instead of .cuda().
        pos_weight = torch.tensor(
            [3932 / 14263], dtype=sequence_logits.dtype, device=sequence_logits.device
        )
        binary_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        sequence_loss = binary_criterion(sequence_logits, sequence_targets)

        return self.alpha * token_loss + (1 - self.alpha) * sequence_loss