gsar78 committed (verified)
Commit 6277793 · Parent(s): e27f3d9

Update custom_model_package/custom_model.py

Files changed (1):
  custom_model_package/custom_model.py  +11 -16
custom_model_package/custom_model.py CHANGED
@@ -10,27 +10,25 @@ class CustomConfig(PretrainedConfig):
         self.num_emotion_labels = num_emotion_labels
 
 class CustomModel(XLMRobertaForSequenceClassification):
-    config_class = CustomConfig
-
-    def __init__(self, config):
+    def __init__(self, config, num_emotion_labels):
         super(CustomModel, self).__init__(config)
-        self.num_emotion_labels = config.num_emotion_labels
+        self.num_emotion_labels = num_emotion_labels
         self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
         self.emotion_classifier = nn.Sequential(
             nn.Linear(config.hidden_size, 512),
             nn.Mish(),
             nn.Dropout(0.3),
-            nn.Linear(512, self.num_emotion_labels)
+            nn.Linear(512, num_emotion_labels)
         )
         self._init_weights(self.emotion_classifier[0])
         self._init_weights(self.emotion_classifier[3])
-
+
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
-
+
     def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
         outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
         sequence_output = outputs[0]
@@ -39,16 +37,13 @@ class CustomModel(XLMRobertaForSequenceClassification):
         cls_hidden_states = sequence_output[:, 0, :]
         cls_hidden_states = self.dropout_emotion(cls_hidden_states)
         emotion_logits = self.emotion_classifier(cls_hidden_states)
-
-        # Sentiment logits from the original classifier
-        sentiment_logits = self.classifier(cls_hidden_states)
-
-        # Concatenate the sentiment and emotion logits
-        logits = torch.cat([sentiment_logits, emotion_logits], dim=-1)
-
+        with torch.no_grad():
+            cls_token_state = sequence_output[:, 0, :].unsqueeze(1)
+            sentiment_logits = self.classifier(cls_token_state).squeeze(1)
         if labels is not None:
             class_weights = torch.tensor([1.0] * self.num_emotion_labels).to(labels.device)
             loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
             loss = loss_fct(emotion_logits, labels)
-            return {"loss": loss, "logits": logits}
-        return {"logits": logits}
+            return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
+        return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
+
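For reference, a minimal usage sketch of the class after this commit is shown below. It assumes custom_model.py already imports torch, torch.nn and the transformers classes visible in the diff; the offline XLMRobertaConfig, the import path custom_model_package.custom_model, and the label counts (3 sentiment classes, 7 emotion labels) are placeholders for illustration, not part of this commit. The key behavioural change it exercises is that forward now returns the two heads under separate keys ("emotion_logits", "sentiment_logits") instead of one concatenated "logits" tensor, with the sentiment head evaluated under torch.no_grad().

import torch
from transformers import XLMRobertaConfig

from custom_model_package.custom_model import CustomModel  # hypothetical import path

# Randomly initialised config so the sketch runs offline; real use would load a fine-tuned checkpoint.
config = XLMRobertaConfig(num_labels=3)            # 3 sentiment classes is an assumption
model = CustomModel(config, num_emotion_labels=7)  # 7 emotion labels is a placeholder
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # dummy batch of 2 sequences
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

# Forward returns the two heads separately after this change.
emotion_probs = torch.sigmoid(out["emotion_logits"])              # multi-label emotion head, shape (2, 7)
sentiment_probs = torch.softmax(out["sentiment_logits"], dim=-1)  # base sentiment head, shape (2, 3)
print(emotion_probs.shape, sentiment_probs.shape)

Because the emotion head is trained with BCEWithLogitsLoss, a per-label sigmoid (not softmax) is the matching inference-time activation for emotion_logits.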