|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import AutoModel |
|
|
|
class SentimentCNNModel(nn.Module):
    """Text classifier: pretrained transformer encoder + multi-width 1-D CNN head.

    The encoder's last hidden layer is treated as a ``hidden_size``-channel
    sequence. It is convolved with one ``nn.Conv1d`` per kernel width, passed
    through ReLU, and max-pooled over time, ignoring windows that touch padded
    positions. The pooled features are concatenated and fed through dropout
    and a linear layer to produce class logits.

    Args:
        transformer_model_name: Name or path passed to
            ``AutoModel.from_pretrained``.
        num_classes: Number of output logits.
        cnn_out_channels: Output channels of each convolution.
        cnn_kernel_sizes: Iterable of convolution kernel widths (in tokens);
            one convolution is built per entry.
        dropout: Dropout probability applied to the pooled features.

    Raises:
        ValueError: If ``cnn_kernel_sizes`` is empty.
    """

    def __init__(
        self,
        transformer_model_name,
        num_classes,
        cnn_out_channels=100,
        cnn_kernel_sizes=(3, 5, 7),
        dropout=0.5,
    ):
        super().__init__()

        # Materialize once: accepts any iterable (lists, tuples, generators)
        # and avoids the shared-mutable-default pitfall of a list default.
        kernel_sizes = tuple(cnn_kernel_sizes)
        if not kernel_sizes:
            raise ValueError("cnn_kernel_sizes must contain at least one kernel size")

        self.transformer = AutoModel.from_pretrained(transformer_model_name)
        hidden_size = self.transformer.config.hidden_size

        # One convolution per kernel width; each reads the encoder's hidden
        # size as its input channels.
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    in_channels=hidden_size,
                    out_channels=cnn_out_channels,
                    kernel_size=k,
                )
                for k in kernel_sizes
            ]
        )
        self.dropout = nn.Dropout(dropout)
        # Each convolution contributes cnn_out_channels pooled features.
        self.fc = nn.Linear(len(kernel_sizes) * cnn_out_channels, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            input_ids: Token ids, shape ``(batch, seq_len)``.
            attention_mask: Nonzero for real tokens, 0 for padding; same shape
                as ``input_ids``. Besides being passed to the encoder, it is
                used to exclude padded positions from max-pooling, so a
                sequence's logits do not depend on how much padding it got.
        """
        encoded = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # (batch, seq_len, hidden) -> (batch, hidden, seq_len): Conv1d
        # convolves over the last axis.
        hidden_states = encoded.last_hidden_state.transpose(1, 2)
        # NOTE(review): assumes the encoder's output length equals
        # attention_mask's length — confirm for any non-standard encoder.
        token_mask = (attention_mask != 0).to(hidden_states.dtype)

        # A sequence shorter than the widest kernel would make Conv1d raise.
        # Right-pad with zeros marked as padding so that kernels wider than
        # the text yield zero features instead of crashing.
        widest = max(conv.kernel_size[0] for conv in self.convs)
        shortfall = widest - hidden_states.size(2)
        if shortfall > 0:
            hidden_states = F.pad(hidden_states, (0, shortfall))
            token_mask = F.pad(token_mask, (0, shortfall))

        pooled_outputs = []
        for conv in self.convs:
            width = conv.kernel_size[0]
            # (batch, out_channels, seq_len - width + 1): stride 1, no padding.
            activations = torch.relu(conv(hidden_states))
            # Window i covers tokens i .. i+width-1; it is valid only if all
            # of them are real tokens. Shape (batch, seq_len - width + 1).
            window_valid = token_mask.unfold(1, width, 1).min(dim=-1).values
            # ReLU output is >= 0, so zeroing invalid windows removes them
            # from the max without altering it. A row with no valid window
            # pools to 0.
            activations = activations * window_valid.unsqueeze(1)
            pooled_outputs.append(activations.max(dim=2).values)

        features = self.dropout(torch.cat(pooled_outputs, dim=1))
        return self.fc(features)
|
|