import torch.nn as nn
from transformers import BertModel


class BertClassifier(nn.Module):
"""Bert Model for Classification Tasks. |
|
""" |
|
def __init__(self, freeze_bert=False): |
|
""" |
|
@param bert: a BertModel object |
|
@param classifier: a torch.nn.Module classifier |
|
@param freeze_bert (bool): Set `False` to fine-tune the BERT model |
|
""" |
|
super(BertClassifier, self).__init__() |
|
|
|
D_in, H, D_out = 768, 50, 2 |
|
|
|
|
|
self.bert = BertModel.from_pretrained('aubmindlab/bert-base-arabertv02') |
|
|
|
|
|
self.classifier = nn.Sequential( |
|
nn.Linear(D_in, H), |
|
nn.ReLU(), |
|
nn.Dropout(0.1), |
|
nn.Linear(H, D_out) |
|
) |
|
|
|
|
|
if freeze_bert: |
|
for param in self.bert.parameters(): |
|
param.requires_grad = False |
|
|
|
def forward(self, input_ids, attention_mask): |
|
""" |
|
Feed input to BERT and the classifier to compute logits. |
|
@param input_ids (torch.Tensor): an input tensor with shape (batch_size, |
|
max_length) |
|
@param attention_mask (torch.Tensor): a tensor that hold attention mask |
|
information with shape (batch_size, max_length) |
|
@return logits (torch.Tensor): an output tensor with shape (batch_size, |
|
num_labels) |
|
""" |
|
|
|
outputs = self.bert(input_ids=input_ids, |
|
attention_mask=attention_mask) |
|
|
|
|
|
last_hidden_state_cls = outputs[0][:, 0, :] |
|
logits = self.classifier(last_hidden_state_cls) |
|
|
|
return logits |
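

# Minimal usage sketch (illustrative, not part of the original module):
# instantiates the classifier, tokenizes one sample with the matching
# AraBERT tokenizer, and runs a forward pass. The sample sentence and
# max_length below are placeholder assumptions for demonstration only.
if __name__ == '__main__':
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
    model = BertClassifier(freeze_bert=True)
    model.eval()

    # Placeholder Arabic sentence ("an example of an Arabic sentence")
    encoding = tokenizer('مثال على جملة عربية',
                         padding='max_length',
                         truncation=True,
                         max_length=64,
                         return_tensors='pt')

    with torch.no_grad():
        logits = model(encoding['input_ids'], encoding['attention_mask'])
    print(logits.shape)  # expected: torch.Size([1, 2])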