# modeling.py
from transformers import XLMRobertaForSequenceClassification, AutoConfig

from .modeling_custom import CustomModel


def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
    # Read num_emotion_labels from the checkpoint's config and forward it to
    # CustomModel so the emotion classification head is sized correctly.
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    num_emotion_labels = config.num_emotion_labels
    model = CustomModel.from_pretrained(pretrained_model_name_or_path, num_emotion_labels=num_emotion_labels, *model_args, **kwargs)
    return model
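

# Example usage (a minimal sketch): "path/to/emotion-checkpoint" is a
# placeholder for a checkpoint whose config.json is assumed to define
# num_emotion_labels.
#
#   from modeling import from_pretrained
#   model = from_pretrained("path/to/emotion-checkpoint")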