from abc import ABCMeta

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertForSequenceClassification, PreTrainedModel  # noqa
from transformers.modeling_outputs import SequenceClassifierOutput

from src.configuration import BertABSAConfig


class BertBaseForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """Plain BERT classifier: delegates the forward pass, including loss
    computation, to BertForSequenceClassification."""

    config_class = BertABSAConfig

    def __init__(self, config):
        super(BertBaseForSequenceClassification, self).__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.dropout = nn.Dropout(config.dropout_rate)
        self.bert = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            output_hidden_states=False,
            output_attentions=False,
            num_labels=self.num_classes)
        print("BERT Model Loaded")

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        out = self.bert(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        labels=labels)
        return out


class BertLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """BERT classifier that feeds the [CLS] vectors of the first
    `num_layers` hidden states through an LSTM before the final
    linear layer."""

    config_class = BertABSAConfig

    def __init__(self, config):
        super(BertLSTMForSequenceClassification, self).__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.hidden_dim_lstm = config.hidden_dim_lstm
        self.dropout = nn.Dropout(config.dropout_rate)
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=False)
        print("BERT Model Loaded")
        self.lstm = nn.LSTM(self.embed_dim, self.hidden_dim_lstm, batch_first=True)  # noqa
        self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        bert_output = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        hidden_states = bert_output["hidden_states"]
        # Collect the [CLS] token of the first `num_layers` hidden states,
        # stacked on dim 1 so the shape is (batch, num_layers, embed_dim).
        hidden_states = torch.stack([hidden_states[layer_i][:, 0]
                                     for layer_i in range(self.num_layers)], dim=1)  # noqa
        # Treat the per-layer [CLS] vectors as a sequence: run them through
        # the LSTM and classify from the last time step.
        out, _ = self.lstm(hidden_states)
        out = self.dropout(out[:, -1, :])
        logits = self.fc(out)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        out = SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )
        return out


class BertAttentionForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):  # noqa
    """BERT classifier that pools the per-layer [CLS] vectors with a
    learned attention query before the final linear layer."""

    config_class = BertABSAConfig

    def __init__(self, config):
        super(BertAttentionForSequenceClassification, self).__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.fc_hidden = config.fc_hidden
        self.dropout = nn.Dropout(config.dropout_rate)
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=False)
        print("BERT Model Loaded")
        # Attention query and projection matrix. Wrap the float32 tensors
        # directly in nn.Parameter so they are registered with the module
        # and move with it on .to(device); calling .float().to(...) on a
        # Parameter would return a plain, untracked tensor.
        q_t = np.random.normal(loc=0.0, scale=0.1, size=(1, self.embed_dim))
        self.q = nn.Parameter(torch.from_numpy(q_t).float())
        w_ht = np.random.normal(loc=0.0, scale=0.1, size=(self.embed_dim, self.fc_hidden))  # noqa
        self.w_h = nn.Parameter(torch.from_numpy(w_ht).float())
        self.fc = nn.Linear(self.fc_hidden, self.num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        bert_output = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        hidden_states = bert_output["hidden_states"]
        # Collect the [CLS] token of the first `num_layers` hidden states,
        # stacked on dim 1 so the shape is (batch, num_layers, embed_dim).
        hidden_states = torch.stack([hidden_states[layer_i][:, 0]
                                     for layer_i in range(self.num_layers)], dim=1)  # noqa
        out = self.attention(hidden_states)
        out = self.dropout(out)
        logits = self.fc(out)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        out = SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )
        return out

    def attention(self, h):
        # h: (batch, num_layers, embed_dim).
        # Score each layer's [CLS] vector against the learned query `q`
        # and normalise the scores over the layer axis.
        v = torch.matmul(self.q, h.transpose(-2, -1)).squeeze(1)
        v = F.softmax(v, -1)
        # Attention-weighted sum of the layer vectors: (batch, embed_dim, 1).
        v_temp = torch.matmul(v.unsqueeze(1), h).transpose(-2, -1)
        # Project to the classifier's hidden size: (batch, fc_hidden).
        v = torch.matmul(self.w_h.transpose(1, 0), v_temp).squeeze(2)
        return v
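

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of instantiating and running one of the models above.
# It assumes BertABSAConfig accepts the fields read in the constructors
# (num_classes, embed_dim, num_layers, hidden_dim_lstm, fc_hidden,
# dropout_rate) as keyword arguments; the hyperparameter values and the
# sample sentence are placeholders, not values from this repository.
if __name__ == "__main__":
    from transformers import BertTokenizer

    config = BertABSAConfig(num_classes=3, embed_dim=768, num_layers=12,
                            hidden_dim_lstm=256, fc_hidden=128,
                            dropout_rate=0.1)
    model = BertLSTMForSequenceClassification(config)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    encoded = tokenizer("the battery life is great but the screen is dim",
                        return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        output = model(input_ids=encoded["input_ids"],
                       attention_mask=encoded["attention_mask"],
                       token_type_ids=encoded["token_type_ids"],
                       labels=torch.tensor([2]))
    print(output.logits.shape)  # (1, num_classes)
    print(output.loss)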