File size: 1,855 Bytes
cad2dfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
class SentimentCNNModel(nn.Module):
    """Sequence classifier: a pretrained transformer followed by a multi-width 1-D CNN head.

    The transformer's last hidden states are convolved with one ``Conv1d``
    per kernel width, ReLU-activated, max-pooled over the sequence (only over
    windows made entirely of real tokens), concatenated, passed through
    dropout and projected to ``num_classes`` logits.

    Args:
        transformer_model_name: Name or path passed to
            ``AutoModel.from_pretrained``.
        num_classes: Number of output logits.
        cnn_out_channels: Output channels of each convolution.
        cnn_kernel_sizes: Kernel widths; one convolution is built per entry.
            Must be non-empty.
        dropout: Dropout probability applied to the pooled features.

    Raises:
        ValueError: If ``cnn_kernel_sizes`` is empty.
    """

    def __init__(
        self,
        transformer_model_name,
        num_classes,
        cnn_out_channels=100,
        cnn_kernel_sizes=(3, 5, 7),
        dropout=0.5,
    ):
        super().__init__()
        # Tuple default (not a list) so no mutable object is shared between calls;
        # materialize the argument once so any iterable of ints is accepted.
        kernel_sizes = tuple(cnn_kernel_sizes)
        if not kernel_sizes:
            # Fail before loading the (possibly large) pretrained weights.
            raise ValueError("cnn_kernel_sizes must contain at least one kernel size")

        # Load pre-trained transformer model
        self.transformer = AutoModel.from_pretrained(transformer_model_name)
        hidden_size = self.transformer.config.hidden_size

        # Stride 1, no padding, dilation 1: each conv maps seq_len -> seq_len - k + 1.
        # forward() relies on this when aligning its window-validity mask.
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(in_channels=hidden_size, out_channels=cnn_out_channels, kernel_size=k)
                for k in kernel_sizes
            ]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(len(kernel_sizes) * cnn_out_channels, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return classification logits of shape ``(batch_size, num_classes)``.

        Args:
            input_ids: Token ids forwarded to the transformer.
            attention_mask: Mask forwarded to the transformer; nonzero entries
                mark real tokens. It also restricts max-pooling to convolution
                windows made only of real tokens, so padding never contributes
                features. A row with no such window (fewer real tokens than a
                kernel's width) gets 0 for that kernel's features, the floor
                of the ReLU.
        """
        transformer_outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = transformer_outputs.last_hidden_state  # (batch, seq_len, hidden)
        hidden_states = hidden_states.transpose(1, 2)  # Conv1d expects (batch, hidden, seq_len)
        token_mask = attention_mask

        # Conv1d raises when its input is shorter than the kernel; right-pad with
        # masked-out zeros so short batches still produce logits.
        shortfall = max(conv.kernel_size[0] for conv in self.convs) - hidden_states.size(2)
        if shortfall > 0:
            hidden_states = F.pad(hidden_states, (0, shortfall))
            token_mask = F.pad(token_mask, (0, shortfall))
        token_mask = token_mask.bool().unsqueeze(1)  # (batch, 1, seq_len)

        pooled_outputs = []
        for conv in self.convs:
            width = conv.kernel_size[0]
            activations = torch.relu(conv(hidden_states))  # (batch, out, seq_len - width + 1)
            # A window is valid only if every position it covers is a real token.
            window_ok = token_mask.unfold(2, width, 1).all(dim=-1)  # (batch, 1, seq_len - width + 1)
            # ReLU outputs are >= 0, so zeroing invalid windows drops them from the
            # max without changing it whenever at least one valid window exists.
            masked = activations.masked_fill(~window_ok, 0.0)
            pooled_outputs.append(masked.max(dim=2).values)

        # Concatenate pooled outputs and apply dropout
        cat_output = self.dropout(torch.cat(pooled_outputs, dim=1))
        # Final classification
        return self.fc(cat_output)
|