File size: 1,571 Bytes
dfe40cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForTokenClassification

class CNNForNER(nn.Module):
    """Token classifier that refines a pretrained model's logits with a 1-D CNN.

    A pretrained token-classification model produces per-token logits; two
    length-preserving ``Conv1d`` layers then mix information between
    neighbouring tokens, and a linear layer maps each token's features to
    ``num_classes`` scores.

    Args:
        pretrained_model_name: Identifier passed to
            ``AutoModelForTokenClassification.from_pretrained``.
        num_classes: Number of output classes per token.
        max_length: Stored on the instance as ``self.max_length``; it is not
            read anywhere else in this class.
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution (and the
            input width of the final linear layer).
        kernel_size: Width of both convolution kernels. Must be a positive
            odd integer so that ``padding = kernel_size // 2`` keeps the
            sequence length unchanged.
        dropout: Dropout probability applied after the second convolution.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        pretrained_model_name: str,
        num_classes: int,
        max_length: int = 128,
        *,
        conv1_channels: int = 256,
        conv2_channels: int = 128,
        kernel_size: int = 3,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()

        # Validate before the (potentially slow) pretrained-model load so a
        # bad configuration fails fast. An even kernel with k // 2 padding
        # would yield sequence_length + 1 outputs and break the one-score-
        # per-token alignment.
        if (
            not isinstance(kernel_size, int)
            or kernel_size < 1
            or kernel_size % 2 == 0
        ):
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size!r}"
            )

        self.transformer = AutoModelForTokenClassification.from_pretrained(
            pretrained_model_name
        )
        self.max_length = max_length

        # The CNN consumes the pretrained head's logits, so its input channel
        # count is the pretrained model's own label count.
        pretrained_num_labels = self.transformer.num_labels

        same_padding = kernel_size // 2
        self.conv1 = nn.Conv1d(
            in_channels=pretrained_num_labels,
            out_channels=conv1_channels,
            kernel_size=kernel_size,
            padding=same_padding,
        )
        self.conv2 = nn.Conv1d(
            in_channels=conv1_channels,
            out_channels=conv2_channels,
            kernel_size=kernel_size,
            padding=same_padding,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(in_features=conv2_channels, out_features=num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Compute per-token class scores.

        Args:
            input_ids: Token ids, forwarded to the wrapped model; expected
                shape ``(batch_size, sequence_length)``.
            attention_mask: Optional mask forwarded unchanged to the wrapped
                model.

        Returns:
            Tensor of shape ``(batch_size, sequence_length, num_classes)``.
        """
        # NOTE(review): attention_mask is only handed to the wrapped model;
        # the convolutions below see every position, padded ones included.
        # Confirm that is intended.
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # (batch_size, sequence_length, pretrained_num_labels)

        # Conv1d expects channels before length.
        features = logits.permute(0, 2, 1)  # (batch_size, pretrained_num_labels, sequence_length)
        features = F.relu(self.conv1(features))
        features = F.relu(self.conv2(features))
        features = self.dropout(features)

        # Back to token-major layout for the per-token linear classifier.
        features = features.permute(0, 2, 1)  # (batch_size, sequence_length, conv2_channels)
        return self.fc(features)  # (batch_size, sequence_length, num_classes)