HellenicSentimentAI_v2 / modeling.py
gsar78's picture
Upload 4 files
9c243c1 verified
raw
history blame
476 Bytes
# modeling.py
from transformers import XLMRobertaForSequenceClassification, AutoConfig
from .modeling_custom import CustomModel
def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
    """Load a ``CustomModel`` checkpoint, supplying its ``num_emotion_labels``.

    The emotion-label count is read from the checkpoint's config
    (``config.num_emotion_labels``) and passed to
    ``CustomModel.from_pretrained`` as the ``num_emotion_labels`` keyword.

    Args:
        pretrained_model_name_or_path: Model id or local directory, forwarded
            unchanged to ``AutoConfig.from_pretrained`` and
            ``CustomModel.from_pretrained``.
        *model_args: Extra positional arguments forwarded to
            ``CustomModel.from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to
            ``CustomModel.from_pretrained``. File-resolution options
            (``revision``, ``cache_dir``, ``token``, ``local_files_only``,
            ``trust_remote_code``, ...) are also forwarded to the config load
            so the config and the weights are resolved from the same source.
            An explicit ``num_emotion_labels`` is used as-is instead of the
            config value (the config is then not loaded here).

    Returns:
        Whatever ``CustomModel.from_pretrained`` returns.

    Raises:
        AttributeError: If no ``num_emotion_labels`` is passed and the loaded
            config has no ``num_emotion_labels`` attribute.
    """
    # Options that affect *where/how* checkpoint files are fetched. They must
    # reach the config load too; otherwise the config could come from another
    # revision/cache, or fail on a private or offline repo.
    hub_loading_keys = (
        "cache_dir",
        "force_download",
        "resume_download",
        "proxies",
        "local_files_only",
        "token",
        "use_auth_token",
        "revision",
        "subfolder",
        "trust_remote_code",
    )

    # Respect a caller-supplied value instead of raising
    # "got multiple values for keyword argument 'num_emotion_labels'".
    if "num_emotion_labels" not in kwargs:
        hub_kwargs = {key: kwargs[key] for key in hub_loading_keys if key in kwargs}
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **hub_kwargs)
        kwargs["num_emotion_labels"] = config.num_emotion_labels

    model = CustomModel.from_pretrained(
        pretrained_model_name_or_path, *model_args, **kwargs
    )
    return model