gsar78 commited on
Commit
10700f0
1 Parent(s): 6277793

Update custom_model_package/custom_model.py

Browse files
custom_model_package/custom_model.py CHANGED
@@ -47,3 +47,16 @@ class CustomModel(XLMRobertaForSequenceClassification):
47
  return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
48
  return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
48
  return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
49
 
50
+ @classmethod
51
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
52
+ config = kwargs.pop('config', None)
53
+ if config is None:
54
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
55
+ num_emotion_labels = kwargs.pop('num_emotion_labels', config.num_emotion_labels)
56
+ model = super(CustomModel, cls).from_pretrained(pretrained_model_name_or_path, config=config, *model_args, **kwargs)
57
+ model.num_emotion_labels = num_emotion_labels
58
+ return model
59
+
60
+ # Register the custom configuration and model
61
+ CONFIG_MAPPING.update({"custom_model": CustomConfig})
62
+ MODEL_MAPPING.update({"custom_model": CustomModel})