# HellenicSentimentAI_v2 / modeling.py (Hugging Face Hub, author: gsar78)
# Last commit: 1c3979d "Update modeling.py" (verified), file size 611 bytes.
from transformers import AutoConfig
from .modeling_custom import CustomModel
def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
    """Load a ``CustomModel`` with its ``num_emotion_labels`` resolved from config.

    The number of emotion labels is resolved in this order:

    1. an explicit ``num_emotion_labels`` keyword argument (if not ``None``);
    2. the ``num_emotion_labels`` attribute of a config object passed as
       ``config=...``;
    3. the ``num_emotion_labels`` attribute of the config loaded with
       ``AutoConfig.from_pretrained(pretrained_model_name_or_path, ...)``.

    Args:
        pretrained_model_name_or_path: Model identifier or local path, passed to
            ``AutoConfig.from_pretrained`` (when a config must be loaded) and to
            ``CustomModel.from_pretrained``.
        *model_args: Extra positional arguments forwarded unchanged to
            ``CustomModel.from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to
            ``CustomModel.from_pretrained``. File-location options such as
            ``revision``, ``cache_dir``, ``token`` or ``trust_remote_code`` are
            also passed to ``AutoConfig.from_pretrained`` so the config is read
            from the same source as the model weights.

    Returns:
        The object returned by ``CustomModel.from_pretrained``.

    Raises:
        AttributeError: If the loaded config has no ``num_emotion_labels``.
    """
    # Options that control *where/how* files are fetched. The config must be
    # read with the same revision, cache, credentials, etc. as the weights,
    # otherwise the label count could come from a different model version.
    config_loading_keys = (
        "cache_dir",
        "force_download",
        "resume_download",
        "proxies",
        "token",
        "use_auth_token",
        "revision",
        "local_files_only",
        "subfolder",
        "trust_remote_code",
    )

    # Pop so it is not passed twice to CustomModel.from_pretrained below.
    num_emotion_labels = kwargs.pop("num_emotion_labels", None)
    if num_emotion_labels is None:
        # Prefer a config object the caller already supplied; only hit the
        # hub / filesystem when there is nothing usable.
        config = kwargs.get("config")
        if not hasattr(config, "num_emotion_labels"):
            config_kwargs = {
                key: kwargs[key] for key in config_loading_keys if key in kwargs
            }
            config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, **config_kwargs
            )
        num_emotion_labels = config.num_emotion_labels

    # Positional args go before keywords; unpacking them after a keyword
    # argument is legal but misleading.
    return CustomModel.from_pretrained(
        pretrained_model_name_or_path,
        *model_args,
        num_emotion_labels=num_emotion_labels,
        **kwargs,
    )