|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import AutoModelForTokenClassification |
|
|
|
class CNNForNER(nn.Module):
    """Token classifier that stacks two 1-D convolutions on a pretrained head.

    The wrapped token-classification model produces per-token logits; those
    logits are treated as input channels for two ``Conv1d`` layers (kernel
    size 3, padding 1, so the sequence length is preserved), followed by
    dropout and a per-token linear projection to ``num_classes`` scores.

    Args:
        pretrained_model_name: Identifier passed to
            ``AutoModelForTokenClassification.from_pretrained``.
        num_classes: Size of the last dimension of the output.
        max_length: Stored on the instance as ``self.max_length``; it is not
            read anywhere else in this class.
    """

    def __init__(self, pretrained_model_name, num_classes, max_length=128):
        super().__init__()
        self.transformer = AutoModelForTokenClassification.from_pretrained(
            pretrained_model_name
        )
        self.max_length = max_length

        # The convolutions consume the pretrained head's logits, so the number
        # of input channels must match the label count of that head.
        pretrained_num_labels = self.transformer.num_labels

        self.conv1 = nn.Conv1d(
            in_channels=pretrained_num_labels,
            out_channels=256,
            kernel_size=3,
            padding=1,
        )
        self.conv2 = nn.Conv1d(
            in_channels=256,
            out_channels=128,
            kernel_size=3,
            padding=1,
        )
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Return per-token class scores.

        Args:
            input_ids: Token ids forwarded unchanged to the wrapped model.
            attention_mask: Mask forwarded unchanged to the wrapped model.
                Defaults to ``None``, which is passed through as-is. The mask
                is not applied to the convolution inputs or outputs here.

        Returns:
            Tensor whose last dimension has ``num_classes`` entries; the
            leading dimensions follow the wrapped model's logits.
        """
        backbone_out = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )

        # Conv1d convolves over the last axis and expects channels on axis 1,
        # so move the label axis in front of the sequence axis.
        channels_first = backbone_out.logits.permute(0, 2, 1)

        features = F.relu(self.conv1(channels_first))
        features = F.relu(self.conv2(features))
        features = self.dropout(features)

        # Back to channels-last so the linear layer acts on each token.
        tokens_first = features.permute(0, 2, 1)
        return self.fc(tokens_first)
|
|