Update custom_model_package/custom_model.py
Browse files
custom_model_package/custom_model.py
CHANGED
@@ -42,12 +42,13 @@ class CustomModel(XLMRobertaForSequenceClassification):
|
|
42 |
with torch.no_grad():
|
43 |
cls_token_state = sequence_output[:, 0, :].unsqueeze(1)
|
44 |
sentiment_logits = self.classifier(cls_token_state).squeeze(1)
|
|
|
45 |
if labels is not None:
|
46 |
class_weights = torch.tensor([1.0] * self.num_emotion_labels).to(labels.device)
|
47 |
loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
|
48 |
loss = loss_fct(emotion_logits, labels)
|
49 |
-
return {"loss": loss, "
|
50 |
-
return {"
|
51 |
|
52 |
@classmethod
|
53 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
|
|
42 |
with torch.no_grad():
|
43 |
cls_token_state = sequence_output[:, 0, :].unsqueeze(1)
|
44 |
sentiment_logits = self.classifier(cls_token_state).squeeze(1)
|
45 |
+
logits = torch.cat([sentiment_logits, emotion_logits], dim=-1)
|
46 |
if labels is not None:
|
47 |
class_weights = torch.tensor([1.0] * self.num_emotion_labels).to(labels.device)
|
48 |
loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
|
49 |
loss = loss_fct(emotion_logits, labels)
|
50 |
+
return {"loss": loss, "logits": logits}
|
51 |
+
return {"logits": logits}
|
52 |
|
53 |
@classmethod
|
54 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|