File size: 1,855 Bytes
cad2dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class SentimentCNNModel(nn.Module):
    """Text classifier: transformer encoder followed by a multi-kernel CNN head.

    The pre-trained transformer produces per-token hidden states. Several
    ``Conv1d`` layers (one per kernel size) slide over the token axis; each
    output is ReLU-activated and max-pooled over time, and the concatenated
    pooled features pass through dropout and a linear layer to give logits.

    Padding is handled explicitly: a convolution window may only contribute
    to the max-pool if every token it covers is marked real by
    ``attention_mask``. A sequence's logits therefore do not depend on how
    much padding the batch placed around it.
    """

    def __init__(
        self,
        transformer_model_name,
        num_classes,
        cnn_out_channels=100,
        cnn_kernel_sizes=(3, 5, 7),
        dropout_prob=0.5,
    ):
        """Build the encoder and the CNN classification head.

        Args:
            transformer_model_name: Name or path passed to
                ``AutoModel.from_pretrained``.
            num_classes: Number of output logits.
            cnn_out_channels: Number of filters per kernel size.
            cnn_kernel_sizes: Iterable of convolution kernel widths (in
                tokens); one ``Conv1d`` is created per entry. The default is
                a tuple, not a list, so it is never shared mutable state.
            dropout_prob: Dropout probability applied to the pooled features.

        Raises:
            ValueError: If ``cnn_kernel_sizes`` is empty.
        """
        super().__init__()
        kernel_sizes = tuple(cnn_kernel_sizes)
        if not kernel_sizes:
            raise ValueError("cnn_kernel_sizes must contain at least one kernel size")

        # Load pre-trained transformer model
        self.transformer = AutoModel.from_pretrained(transformer_model_name)
        hidden_size = self.transformer.config.hidden_size

        # One convolution per kernel size, all reading the transformer hidden states
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=hidden_size, out_channels=cnn_out_channels, kernel_size=k)
            for k in kernel_sizes
        ])

        # Dropout on the concatenated pooled features
        self.dropout = nn.Dropout(dropout_prob)

        # Fully connected layer: one pooled vector of cnn_out_channels per kernel size
        self.fc = nn.Linear(len(kernel_sizes) * cnn_out_channels, num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Compute class logits for a batch of token sequences.

        Args:
            input_ids: Token ids, forwarded to the transformer.
            attention_mask: Optional mask, expected to share the
                (batch_size, seq_len) layout of the transformer output;
                nonzero entries mark real tokens. ``None`` treats every
                position as a real token.

        Returns:
            Logits of shape (batch_size, num_classes).
        """
        transformer_outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = transformer_outputs.last_hidden_state  # Shape: (batch_size, seq_len, hidden_size)
        batch_size, seq_len = hidden_states.shape[:2]

        if attention_mask is None:
            token_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=hidden_states.device)
        else:
            token_mask = attention_mask.to(device=hidden_states.device, dtype=torch.bool)

        # Conv1d raises if its input is shorter than the kernel. Right-pad to the
        # widest kernel; padded positions are marked as non-tokens, so any
        # window touching them is excluded by the window mask below.
        max_kernel = max(conv.kernel_size[0] for conv in self.convs)
        if seq_len < max_kernel:
            extra = max_kernel - seq_len
            hidden_states = F.pad(hidden_states, (0, 0, 0, extra))
            token_mask = torch.cat([token_mask, token_mask.new_zeros(batch_size, extra)], dim=1)

        # Transpose for CNN input: (batch_size, hidden_size, seq_len)
        conv_input = hidden_states.transpose(1, 2)

        pooled_outputs = []
        for conv in self.convs:
            width = conv.kernel_size[0]
            # (batch_size, out_channels, seq_len - width + 1): stride 1, no padding
            activations = torch.relu(conv(conv_input))
            # window_valid[b, t] is True iff tokens t .. t + width - 1 are all real.
            window_valid = token_mask.unfold(1, width, 1).all(dim=-1)
            # ReLU outputs are >= 0, so filling invalid windows with 0 never beats
            # a valid window in the max, and yields 0 when no window is valid.
            activations = activations.masked_fill(~window_valid.unsqueeze(1), 0.0)
            pooled_outputs.append(activations.max(dim=2).values)

        # Concatenate pooled outputs and apply dropout
        cat_output = self.dropout(torch.cat(pooled_outputs, dim=1))

        # Final classification
        return self.fc(cat_output)