File size: 655 Bytes
4b51a56
 
 
 
 
 
 
 
 
 
 
8b209d1
4b51a56
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch import nn

class ClassificationModel(nn.Module):
    """Feed-forward classification head stacked on a sequence encoder.

    The encoder output at the first sequence position is passed through
    ``Linear -> ReLU -> Dropout -> Linear -> LogSoftmax``.  The layer order
    inside ``self.classifier`` is fixed so that ``state_dict`` keys
    (``classifier.0.*`` and ``classifier.3.*``) stay stable.

    Args:
        base_model: Callable (typically an ``nn.Module``) invoked as
            ``base_model(input_ids=..., attention_mask=...)``.  Its result
            must expose a ``last_hidden_state`` attribute that can be
            indexed with ``[:, 0, :]``.
        hidden_size: Size of the last dimension of ``last_hidden_state``.
            Must match the encoder; the default of 768 is the previously
            hard-coded value.
        intermediate_size: Width of the hidden layer of the head.
        num_classes: Number of output classes.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(
        self,
        base_model,
        *,
        hidden_size: int = 768,
        intermediate_size: int = 256,
        num_classes: int = 8,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(intermediate_size, num_classes),
            # Normalizes over dim 1, the class dimension of the 2-D head
            # output.
            nn.LogSoftmax(dim=1),
        )

    def forward(self, input_ids, attention_mask):
        """Return per-class log-probabilities for each input sequence.

        Args:
            input_ids: Forwarded unchanged to ``base_model`` as the
                ``input_ids`` keyword argument.
            attention_mask: Forwarded unchanged to ``base_model`` as the
                ``attention_mask`` keyword argument.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding
            log-probabilities (the output of ``LogSoftmax``), not
            probabilities.  NOTE(review): this pairs with ``nn.NLLLoss``;
            confirm the training loop does not apply ``CrossEntropyLoss``
            on top of it.
        """
        encoder_output = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        # Keep only the first position of every sequence; presumably the
        # [CLS] token of a BERT-style encoder -- verify against base_model.
        first_token_state = encoder_output.last_hidden_state[:, 0, :]
        log_probs = self.classifier(first_token_state)
        return log_probs