File size: 1,571 Bytes
dfe40cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForTokenClassification
class CNNForNER(nn.Module):
    """Token classifier: a pretrained transformer followed by a small 1-D CNN head.

    The pretrained token-classification model produces per-token logits. These are
    treated as channels and passed through two length-preserving ``Conv1d`` layers,
    dropout, and a linear layer that maps each token to ``num_classes`` scores.

    Args:
        pretrained_model_name: Name or path passed to
            ``AutoModelForTokenClassification.from_pretrained``.
        num_classes: Number of output classes per token.
        max_length: Stored as ``self.max_length``; not used by this class itself.
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution.
        kernel_size: Convolution kernel width. Must be odd so that
            ``padding=kernel_size // 2`` preserves the sequence length.
        dropout: Dropout probability applied after the second convolution.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        pretrained_model_name,
        num_classes,
        max_length=128,
        *,
        conv1_channels=256,
        conv2_channels=128,
        kernel_size=3,
        dropout=0.3,
    ):
        super().__init__()
        # Validate before the (potentially slow) model download.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.transformer = AutoModelForTokenClassification.from_pretrained(
            pretrained_model_name
        )
        self.max_length = max_length
        # The CNN consumes the transformer's per-token logits, so its input width is
        # the pretrained model's label count, not the transformer's hidden size.
        # NOTE(review): if the checkpoint has no trained token-classification head,
        # these logits come from a freshly initialised classifier with the config's
        # default label count — likely a very narrow bottleneck. Consider feeding
        # hidden states instead (this would change the architecture/state_dict).
        pretrained_num_labels = self.transformer.num_labels
        padding = kernel_size // 2  # same-length output for odd kernels
        self.conv1 = nn.Conv1d(
            in_channels=pretrained_num_labels,
            out_channels=conv1_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.conv2 = nn.Conv1d(
            in_channels=conv1_channels,
            out_channels=conv2_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(in_features=conv2_channels, out_features=num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Compute per-token class scores.

        Args:
            input_ids: Token ids forwarded to the transformer.
            attention_mask: Optional mask with the same leading shape as
                ``input_ids``. It is forwarded to the transformer. Positions where
                it is 0 are also zeroed before each convolution, so padding cannot
                leak into neighbouring real tokens. ``None`` disables masking.

        Returns:
            Tensor of shape ``(batch_size, sequence_length, num_classes)``. Padded
            positions still receive scores; exclude them from the loss.
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # (batch_size, sequence_length, pretrained_num_labels)

        # Conv1d expects channels first: (batch_size, pretrained_num_labels, seq_len).
        features = logits.permute(0, 2, 1)

        mask = None
        if attention_mask is not None:
            # (batch_size, 1, seq_len), broadcast over channels; cast so the
            # integer mask can scale float features.
            mask = attention_mask.unsqueeze(1).to(features.dtype)
            features = features * mask

        hidden = F.relu(self.conv1(features))
        if mask is not None:
            # conv1 emits nonzero values (bias + ReLU) at padded positions; zero
            # them again so conv2 treats padding like its own zero padding.
            hidden = hidden * mask
        hidden = F.relu(self.conv2(hidden))
        hidden = self.dropout(hidden)

        hidden = hidden.permute(0, 2, 1)  # (batch_size, seq_len, conv2_channels)
        return self.fc(hidden)  # (batch_size, seq_len, num_classes)
|