gsar78 committed
Commit 177a5fe
1 Parent(s): d3dab3b

Update custom_model_package/custom_model.py

custom_model_package/custom_model.py CHANGED
@@ -42,12 +42,13 @@ class CustomModel(XLMRobertaForSequenceClassification):
         with torch.no_grad():
             cls_token_state = sequence_output[:, 0, :].unsqueeze(1)
             sentiment_logits = self.classifier(cls_token_state).squeeze(1)
+        logits = torch.cat([sentiment_logits, emotion_logits], dim=-1)
         if labels is not None:
             class_weights = torch.tensor([1.0] * self.num_emotion_labels).to(labels.device)
             loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
             loss = loss_fct(emotion_logits, labels)
-            return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
-        return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
+            return {"loss": loss, "logits": logits}
+        return {"logits": logits}
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
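
After this change, forward() returns a single "logits" tensor, with the sentiment logits and the emotion logits concatenated along the last dimension, instead of two separate output keys; a single "logits" key is the output convention that transformers utilities such as Trainer generally expect. A minimal sketch of how a caller could split that tensor back into its two heads follows; the head sizes (num_sentiment_labels = 3, num_emotion_labels = 11) and the single-label sentiment assumption are illustrative, not values taken from this file:

import torch

# Assumed head sizes -- replace with the actual model config values.
num_sentiment_labels = 3   # hypothetical size of the sentiment classifier head
num_emotion_labels = 11    # hypothetical size of the emotion head

# Stand-in for outputs["logits"], shape (batch, num_sentiment_labels + num_emotion_labels).
logits = torch.randn(4, num_sentiment_labels + num_emotion_labels)

# Undo the concatenation in the same order it was built:
# torch.cat([sentiment_logits, emotion_logits], dim=-1)
sentiment_logits, emotion_logits = torch.split(
    logits, [num_sentiment_labels, num_emotion_labels], dim=-1
)

sentiment_probs = sentiment_logits.softmax(dim=-1)  # assumed single-label sentiment
emotion_probs = torch.sigmoid(emotion_logits)       # multi-label emotions (trained with BCEWithLogitsLoss)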