gsar78 committed
Commit d7434a5 (1 parent: 66ee760)

Create custom_model.py

Files changed (1): custom_model.py (+59, -0)
custom_model.py ADDED
from transformers import AutoConfig, AutoModel, PretrainedConfig, XLMRobertaForSequenceClassification
import torch
import torch.nn as nn


class CustomConfig(PretrainedConfig):
    model_type = "custom_model"

    def __init__(self, num_emotion_labels=18, **kwargs):
        super().__init__(**kwargs)
        self.num_emotion_labels = num_emotion_labels


class CustomModel(XLMRobertaForSequenceClassification):
    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_emotion_labels = config.num_emotion_labels
        self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
        # Auxiliary multi-label emotion head on top of the [CLS] representation.
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(512, self.num_emotion_labels),
        )
        self._init_weights(self.emotion_classifier[0])
        self._init_weights(self.emotion_classifier[3])

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
        # `sentiment` is accepted for interface compatibility but not used here.
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)
        if sequence_output.dim() != 3:
            raise ValueError(f"Expected sequence_output to have 3 dimensions, got {sequence_output.shape}")
        # Emotion head: dropout + MLP over the [CLS] token.
        cls_hidden_states = self.dropout_emotion(sequence_output[:, 0, :])
        emotion_logits = self.emotion_classifier(cls_hidden_states)
        # Sentiment head: reuse the parent class's classification head without
        # gradients. The head indexes features[:, 0, :] itself, so it is fed a
        # 3-D tensor containing only the [CLS] position.
        with torch.no_grad():
            cls_token_state = sequence_output[:, 0, :].unsqueeze(1)
            sentiment_logits = self.classifier(cls_token_state).squeeze(1)
        if labels is not None:
            # Uniform pos_weight is a placeholder; adjust per class to counter imbalance.
            class_weights = torch.ones(self.num_emotion_labels, device=labels.device)
            loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
            # BCEWithLogitsLoss expects float multi-hot targets.
            loss = loss_fct(emotion_logits, labels.float())
            return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
        return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}


# Register the custom configuration and model with the Auto classes.
# (The public AutoConfig/AutoModel.register API replaces the removed
# auto_class_factory helper and the private CONFIG_MAPPING/MODEL_MAPPING.)
AutoConfig.register("custom_model", CustomConfig)
AutoModel.register(CustomConfig, CustomModel)
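
For reference, a minimal usage sketch (not part of this commit) exercising both heads. "xlm-roberta-base" is an assumed stand-in checkpoint; the repository's actual weights and tokenizer may differ, and the emotion head (and, until fine-tuned, the sentiment head) is randomly initialized when loading a base checkpoint.

from transformers import AutoTokenizer

# Assumed stand-in checkpoint; substitute the repository's own weights.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
config = CustomConfig.from_pretrained("xlm-roberta-base", num_emotion_labels=18)
model = CustomModel.from_pretrained("xlm-roberta-base", config=config)
model.eval()

inputs = tokenizer("I am thrilled with the results!", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Multi-label emotions: an independent sigmoid per class, thresholded at 0.5.
emotion_probs = torch.sigmoid(out["emotion_logits"])
predicted_emotions = (emotion_probs[0] > 0.5).nonzero(as_tuple=True)[0].tolist()
print(predicted_emotions, out["sentiment_logits"].shape)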