RealKintaro's picture
Init
7f9da02
raw
history blame
1.97 kB
import torch.nn as nn
from transformers import BertModel
class BertClassifier(nn.Module):
    """BERT encoder followed by a small feed-forward classification head.

    The head is ``Linear -> ReLU -> Dropout -> Linear`` and is applied to the
    last hidden state of the first token (``[CLS]``) of every sequence.
    """

    def __init__(
        self,
        freeze_bert=False,
        *,
        model_name='aubmindlab/bert-base-arabertv02',
        hidden_size=50,
        num_labels=2,
        dropout=0.1,
    ):
        """
        @param freeze_bert (bool): Set `False` to fine-tune the BERT model,
            `True` to keep the encoder weights fixed and train only the head
        @param model_name (str): name or path of the pretrained checkpoint
            handed to `BertModel.from_pretrained`
        @param hidden_size (int): width of the hidden layer of the classifier
        @param num_labels (int): number of output logits per example
        @param dropout (float): dropout probability used inside the classifier
        """
        super().__init__()

        # Instantiate BERT model
        self.bert = BertModel.from_pretrained(model_name)

        # Read the encoder width from the loaded config rather than assuming
        # 768, so the head always matches the checkpoint that was loaded.
        encoder_dim = self.bert.config.hidden_size

        # Two-layer feed-forward classifier. The Sequential layout (indices
        # 0..3) is kept as-is so existing state_dict keys remain loadable.
        self.classifier = nn.Sequential(
            nn.Linear(encoder_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels),
        )

        # Freeze the BERT model: no gradients for any encoder parameter
        if freeze_bert:
            self.bert.requires_grad_(False)

    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits.
        @param input_ids (torch.Tensor): an input tensor with shape (batch_size,
            max_length)
        @param attention_mask (torch.Tensor): a tensor that holds attention mask
            information with shape (batch_size, max_length)
        @return logits (torch.Tensor): an output tensor with shape (batch_size,
            num_labels); raw scores, no softmax is applied
        """
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # outputs[0] is the last hidden state; position 0 along the sequence
        # axis is the `[CLS]` token, whose vector summarizes the sequence for
        # classification.
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed the `[CLS]` vectors to the classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits