File size: 2,789 Bytes
d7434a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, XLMRobertaConfig, XLMRobertaForSequenceClassification
from transformers.models.auto.modeling_auto import auto_class_factory

class CustomConfig(XLMRobertaConfig):
    """Configuration for ``CustomModel``: an XLM-RoBERTa config plus an emotion-head size.

    Subclasses ``XLMRobertaConfig`` rather than bare ``PretrainedConfig``.
    ``CustomModel`` (and the XLM-RoBERTa encoder it inherits) reads
    ``hidden_size``, ``hidden_dropout_prob``, ``initializer_range``,
    ``vocab_size`` and similar attributes. With a bare ``PretrainedConfig``
    those only existed if every one was passed explicitly through ``kwargs``.
    ``XLMRobertaConfig`` supplies defaults for all of them, and it is still a
    ``PretrainedConfig`` subclass.

    Args:
        num_emotion_labels: number of outputs of the emotion classification head.
        **kwargs: forwarded unchanged to ``XLMRobertaConfig``.
    """

    model_type = "custom_model"

    def __init__(self, num_emotion_labels=18, **kwargs):
        super().__init__(**kwargs)
        self.num_emotion_labels = num_emotion_labels

class CustomModel(XLMRobertaForSequenceClassification):
    """XLM-RoBERTa sequence classifier with an additional multi-label emotion head.

    The inherited ``self.classifier`` head produces ``sentiment_logits``.
    A separate MLP (``self.emotion_classifier``) on the first-token hidden
    state produces ``emotion_logits``. Only the emotion head contributes to
    the returned loss.

    The former ``@auto_class_factory("modeling")`` decorator was removed.
    ``auto_class_factory`` takes ``(name, model_mapping, ...)`` and builds a
    new Auto class; it is not a class decorator, so that call raised
    ``TypeError`` at import time.
    """

    config_class = CustomConfig

    def __init__(self, config):
        """Build the encoder, the inherited sentiment head and the emotion head.

        Args:
            config: a ``CustomConfig``; ``num_emotion_labels``, ``hidden_size``,
                ``hidden_dropout_prob`` and ``initializer_range`` are read here.
        """
        super().__init__(config)
        self.num_emotion_labels = config.num_emotion_labels
        self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(512, self.num_emotion_labels),
        )
        # The emotion head is created after super().__init__(), so initialise
        # its layers explicitly (only its nn.Linear layers are affected).
        self.emotion_classifier.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        ``nn.Linear`` layers get weights drawn from N(0, initializer_range)
        and a zero bias. Every other module type is delegated to the parent
        implementation.

        A Linear-only override would replace the parent hook that the base
        constructor applies across the whole model (e.g. via
        ``post_init()``/``init_weights()``). That would silently skip the
        parent's initialisation of embeddings and LayerNorms.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        else:
            super()._init_weights(module)

    def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
        """Run the encoder and both classification heads.

        Args:
            input_ids: token ids, passed straight to ``self.roberta``.
            attention_mask: attention mask, passed straight to ``self.roberta``.
            sentiment: accepted for interface compatibility; currently unused.
            labels: optional multi-hot emotion targets, presumably
                (batch, num_emotion_labels) — TODO confirm. Integer tensors
                are cast to the logits' dtype, as BCE-with-logits requires.

        Returns:
            dict with ``emotion_logits`` and ``sentiment_logits``. When
            ``labels`` is given, a leading ``loss`` entry (BCE-with-logits on
            the emotion head) is also included.

        Raises:
            ValueError: if the encoder's first output is not 3-dimensional.
        """
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        if sequence_output.dim() != 3:
            raise ValueError(f"Expected sequence_output to have 3 dimensions, got {sequence_output.shape}")
        first_token_state = sequence_output[:, 0, :]
        emotion_logits = self.emotion_classifier(self.dropout_emotion(first_token_state))

        # The sentiment head runs without gradients, so this model's loss
        # never updates it. The inherited head presumably indexes
        # features[:, 0, :] — hence the length-1 sequence axis added here.
        with torch.no_grad():
            sentiment_logits = self.classifier(first_token_state.unsqueeze(1)).squeeze(1)

        result = {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
        if labels is None:
            return result

        # An all-ones pos_weight is mathematically identical to no pos_weight,
        # so the per-call tensor allocation is dropped.
        targets = labels if labels.is_floating_point() else labels.to(emotion_logits.dtype)
        loss = nn.BCEWithLogitsLoss()(emotion_logits, targets)
        return {"loss": loss, **result}

# Register the custom configuration and model with the Auto classes so that
# AutoConfig / AutoModel can resolve model_type "custom_model".
# These two private mappings are no longer used below; the imports are retained.
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_MAPPING

# The public registration API performs the same mapping updates as
# CONFIG_MAPPING.register / MODEL_MAPPING.register. It also verifies that
# CustomConfig.model_type and CustomModel.config_class agree with the keys used.
AutoConfig.register("custom_model", CustomConfig)
AutoModel.register(CustomConfig, CustomModel)