from transformers import AutoConfig | |
from .modeling_custom import CustomModel | |
def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
    """Load a ``CustomModel`` with ``num_emotion_labels`` filled in.

    Unless the caller supplies ``num_emotion_labels`` explicitly, the value
    is read from the ``num_emotion_labels`` attribute of the configuration
    that ``AutoConfig.from_pretrained`` returns for the same checkpoint.

    Args:
        pretrained_model_name_or_path: Checkpoint identifier or path, passed
            unchanged to both ``AutoConfig.from_pretrained`` and
            ``CustomModel.from_pretrained``.
        *model_args: Extra positional arguments forwarded to
            ``CustomModel.from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to
            ``CustomModel.from_pretrained``. A ``num_emotion_labels`` entry,
            if present, overrides the value stored in the configuration.

    Returns:
        Whatever ``CustomModel.from_pretrained`` returns.

    Raises:
        AttributeError: If no override is given and the loaded configuration
            has no ``num_emotion_labels`` attribute.
    """
    if "num_emotion_labels" in kwargs:
        # An explicit value wins. Popping it avoids passing the keyword twice
        # (which would raise TypeError) and makes the config lookup unneeded.
        num_emotion_labels = kwargs.pop("num_emotion_labels")
    else:
        # Resolve the config from the same location/revision/credentials as
        # the weights. Only location- and auth-related options are forwarded
        # so that model-only keyword arguments do not leak into the config.
        hub_option_names = (
            "cache_dir",
            "force_download",
            "proxies",
            "revision",
            "token",
            "local_files_only",
            "trust_remote_code",
        )
        config_kwargs = {
            name: kwargs[name] for name in hub_option_names if name in kwargs
        }
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **config_kwargs
        )
        num_emotion_labels = config.num_emotion_labels

    return CustomModel.from_pretrained(
        pretrained_model_name_or_path,
        *model_args,
        num_emotion_labels=num_emotion_labels,
        **kwargs,
    )