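# Web-page header captured along with this file (not code):
# trek90s · "add model" · commit 617dc35 · 5.48 kB
# (page controls: raw | history | blame | contribute | delete | "No virus")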
from abc import ABCMeta
import numpy as np
import torch
from transformers.pytorch_utils import nn
import torch.nn.functional as F
from src.configuration import BertABSAConfig
from transformers import BertModel, BertForSequenceClassification, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
class BertBaseForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """Wrapper that delegates to a ``BertForSequenceClassification`` model.

    ``__init__`` loads ``bert-base-uncased`` through
    ``BertForSequenceClassification.from_pretrained`` with ``num_labels`` set
    to ``config.num_classes``. ``forward`` passes every input straight through
    to that model and returns whatever it returns.
    """

    config_class = BertABSAConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        # Stored from the config; not read anywhere in this class.
        self.embed_dim = config.embed_dim
        # Created from the config but not applied in forward().
        self.dropout = nn.Dropout(config.dropout_rate)

        load_kwargs = {
            "output_hidden_states": False,
            "output_attentions": False,
            "num_labels": self.num_classes,
        }
        self.bert = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', **load_kwargs
        )
        print("BERT Model Loaded")

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        """Forward all inputs (including ``labels``) to the wrapped model.

        Returns the wrapped model's output object unchanged.
        """
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        )
class BertLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """BERT encoder followed by an LSTM over per-layer [CLS] vectors.

    Processing steps:

    1. Take the [CLS] vector (token position 0) from each of the first
       ``config.num_layers`` entries of BERT's ``hidden_states``.
    2. Run a single-layer ``nn.LSTM`` (``batch_first=True``) over that
       sequence of ``num_layers`` vectors.
    3. Apply dropout to the LSTM output at the last step and classify it
       with a linear layer.

    Returns a ``SequenceClassifierOutput``. Its ``loss`` is
    ``F.cross_entropy`` when ``labels`` are given, otherwise ``None``.
    """

    config_class = BertABSAConfig

    def __init__(self, config):
        super(BertLSTMForSequenceClassification, self).__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.hidden_dim_lstm = config.hidden_dim_lstm
        self.dropout = nn.Dropout(config.dropout_rate)
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=False)
        print("BERT Model Loaded")
        # Input size is embed_dim, so it must equal BERT's hidden size.
        self.lstm = nn.LSTM(self.embed_dim, self.hidden_dim_lstm, batch_first=True)
        self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)

    def _stack_cls_per_layer(self, hidden_states):
        """Return a ``(batch, num_layers, embed_dim)`` tensor of [CLS] vectors.

        ``[b, l, :]`` is the position-0 vector of ``hidden_states[l]``.
        Per the transformers docs, ``hidden_states[0]`` is the embedding
        output. This range therefore covers the embeddings plus the first
        ``num_layers - 1`` encoder layers.
        NOTE(review): confirm that is the intended layer selection.
        """
        # Stack on dim=1 so the layer axis sits between batch and features.
        # The old stack(dim=-1) + view(-1, num_layers, embed_dim) built
        # (batch, embed_dim, num_layers) and reinterpreted that memory
        # without transposing. That scrambled features across layers.
        #
        # The old .squeeze() is dropped: [:, 0] already gives
        # (batch, embed_dim), and squeeze() would also remove the batch
        # axis for a batch of 1.
        #
        # NOTE: LSTM weights trained with the old (scrambled) ordering saw
        # differently arranged inputs and should be retrained.
        cls_vectors = [hidden_states[layer_i][:, 0]
                       for layer_i in range(self.num_layers)]
        return torch.stack(cls_vectors, dim=1)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        """Encode with BERT, run the LSTM over layers, and classify."""
        bert_output = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        layer_cls = self._stack_cls_per_layer(bert_output["hidden_states"])
        lstm_out, _ = self.lstm(layer_cls, None)
        # Classify from the LSTM output at the final step (the last stacked layer).
        logits = self.fc(self.dropout(lstm_out[:, -1, :]))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )
class BertAttentionForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """BERT encoder with learned attention pooling over per-layer [CLS] vectors.

    Processing steps:

    1. Stack the [CLS] vector (token position 0) of each of the first
       ``config.num_layers`` entries of BERT's ``hidden_states`` into a
       ``(batch, num_layers, embed_dim)`` tensor.
    2. Score each layer with the learned query ``q`` of shape
       ``(1, embed_dim)`` and softmax the scores.
    3. Take the weighted sum of layers and project it with ``w_h`` of shape
       ``(embed_dim, fc_hidden)``.
    4. Apply dropout and a linear classifier.

    Returns a ``SequenceClassifierOutput``. Its ``loss`` is
    ``F.cross_entropy`` when ``labels`` are given, otherwise ``None``.
    """

    config_class = BertABSAConfig

    def __init__(self, config):
        super(BertAttentionForSequenceClassification, self).__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.fc_hidden = config.fc_hidden
        self.dropout = nn.Dropout(config.dropout_rate)
        self.bert = BertModel.from_pretrained('bert-base-uncased',
                                              output_hidden_states=True,
                                              output_attentions=False)
        print("BERT Model Loaded")
        # Cast to float32 *before* wrapping in nn.Parameter.
        #
        # The previous nn.Parameter(...).float().to(self.device) applied the
        # cast to the Parameter, which returns a plain Tensor. Assigning a
        # plain Tensor does not register it, so q and w_h:
        #   - were invisible to the optimizer,
        #   - were absent from state_dict (re-randomised on every load),
        #   - were not moved by model.to(device).
        #
        # As registered parameters they follow the module's device.
        q_t = np.random.normal(loc=0.0, scale=0.1, size=(1, self.embed_dim))
        self.q = nn.Parameter(torch.from_numpy(q_t).float())
        w_ht = np.random.normal(loc=0.0, scale=0.1,
                                size=(self.embed_dim, self.fc_hidden))
        self.w_h = nn.Parameter(torch.from_numpy(w_ht).float())
        self.fc = nn.Linear(self.fc_hidden, self.num_classes)

    def _stack_cls_per_layer(self, hidden_states):
        """Return a ``(batch, num_layers, embed_dim)`` tensor of [CLS] vectors.

        ``[b, l, :]`` is the position-0 vector of ``hidden_states[l]``.
        Per the transformers docs, ``hidden_states[0]`` is the embedding
        output. NOTE(review): confirm including it is intended.
        """
        # Stack on dim=1 rather than stack(dim=-1) + view(): the old form
        # reinterpreted (batch, embed_dim, num_layers) memory as
        # (batch, num_layers, embed_dim), which mixed features across layers.
        # No .squeeze(): it would drop the batch axis for a batch of 1.
        cls_vectors = [hidden_states[layer_i][:, 0]
                       for layer_i in range(self.num_layers)]
        return torch.stack(cls_vectors, dim=1)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        """Encode with BERT, attention-pool over layers, and classify."""
        bert_output = self.bert(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        layer_cls = self._stack_cls_per_layer(bert_output["hidden_states"])
        pooled = self.attention(layer_cls)
        logits = self.fc(self.dropout(pooled))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )

    def attention(self, h):
        """Attention-pool ``h`` over its layer axis and project with ``w_h``.

        Args:
            h: tensor of per-layer vectors. It is assumed to be
                ``(batch, num_layers, embed_dim)``, as produced by
                ``_stack_cls_per_layer``.

        Returns:
            A ``(batch, fc_hidden)`` tensor.
        """
        # (1, E) @ (B, E, L) -> (B, 1, L) -> (B, L): one score per layer.
        scores = torch.matmul(self.q, h.transpose(-2, -1)).squeeze(1)
        weights = F.softmax(scores, -1)
        # (B, 1, L) @ (B, L, E) -> (B, 1, E) -> (B, E, 1): weighted sum over layers.
        pooled = torch.matmul(weights.unsqueeze(1), h).transpose(-2, -1)
        # (H, E) @ (B, E, 1) -> (B, H, 1) -> (B, H).
        return torch.matmul(self.w_h.transpose(1, 0), pooled).squeeze(2)