|
from torch import nn |
|
from transformers import BertForMaskedLM, PreTrainedModel |
|
|
|
from src.config import PunctuationBertConfig |
|
|
|
|
|
class BertForPunctuation(PreTrainedModel):
    """Punctuation classifier built on top of a masked-LM BERT head.

    The model runs a window of ``backward_context + forward_context + 2``
    token ids through ``BertForMaskedLM``, flattens the per-token vocabulary
    logits into one vector, and maps it through BatchNorm -> Dropout -> Linear
    to ``config.output_size`` punctuation classes.
    """

    config_class = PunctuationBertConfig

    def __init__(self, config):
        """Build the BERT backbone and the flat classification head.

        Args:
            config: ``PunctuationBertConfig`` providing ``backward_context``,
                ``forward_context``, ``vocab_size``, ``dropout`` and
                ``output_size`` in addition to the standard BERT fields.
        """
        super().__init__(config)

        # Fixed input window: context on both sides plus 2 extra positions
        # (presumably special tokens — confirm against the data pipeline).
        segment_size = config.backward_context + config.forward_context + 2
        bert_vocab_size = config.vocab_size

        self.bert = BertForMaskedLM(config)
        # The head consumes the flattened (segment_size * vocab_size) logits,
        # so these layers are sized for that flat vector.
        self.bn = nn.BatchNorm1d(segment_size * bert_vocab_size)
        self.fc = nn.Linear(segment_size * bert_vocab_size, config.output_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attention_mask=None):
        """Classify the punctuation for a batch of token-id windows.

        Args:
            x: Long tensor of token ids, shape ``(batch, segment_size)``.
            attention_mask: Optional mask of the same shape (1 = real token,
                0 = padding). Passed through to BERT so padded positions are
                not attended to; omitting it preserves the previous behavior.

        Returns:
            Tensor of shape ``(batch, config.output_size)`` with class logits.
        """
        # Index [0] selects the MLM logits (batch, seq, vocab) from the
        # model output; no labels are passed, so no loss is returned.
        x = self.bert(input_ids=x, attention_mask=attention_mask)[0]
        # Flatten per-token vocab logits into one vector per example.
        x = x.view(x.shape[0], -1)
        x = self.fc(self.dropout(self.bn(x)))
        return x
|
|