# File-viewer header (kept as comments so the module parses):
# File size: 611 Bytes
# Revisions: 1c3979d, 9c243c1
from transformers import AutoConfig
from .modeling_custom import CustomModel
def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
    """Load a ``CustomModel`` whose emotion-label count comes from its config.

    The configuration at ``pretrained_model_name_or_path`` is read with
    ``AutoConfig`` to obtain ``num_emotion_labels``. That value is passed to
    ``CustomModel.from_pretrained`` along with any extra arguments.

    Args:
        pretrained_model_name_or_path: Model id or local directory, forwarded
            unchanged to both ``AutoConfig.from_pretrained`` and
            ``CustomModel.from_pretrained``.
        *model_args: Extra positional arguments for
            ``CustomModel.from_pretrained``.
        **kwargs: Extra keyword arguments for ``CustomModel.from_pretrained``.
            Source-selection options (``revision``, ``cache_dir``, ``token``,
            ``trust_remote_code``, ...) are also forwarded to ``AutoConfig``,
            so the config is read from the same place as the weights. An
            explicit ``num_emotion_labels`` overrides the value in the config.

    Returns:
        Whatever ``CustomModel.from_pretrained`` returns.

    Raises:
        AttributeError: if ``num_emotion_labels`` is neither passed explicitly
            nor present on the loaded config.
    """
    # Options that decide *where* the files come from. They must reach
    # AutoConfig too, otherwise the config may be read from a different
    # revision/cache than the weights (or fail for private/offline use).
    # Model-only kwargs (dtype, device placement, ...) are not forwarded here.
    source_keys = (
        "cache_dir",
        "force_download",
        "local_files_only",
        "proxies",
        "revision",
        "subfolder",
        "token",
        "use_auth_token",
        "trust_remote_code",
    )
    config_kwargs = {key: kwargs[key] for key in source_keys if key in kwargs}
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **config_kwargs)

    # Pop an explicit override so the keyword is not passed twice below
    # (which would raise "got multiple values for keyword argument").
    if "num_emotion_labels" in kwargs:
        num_emotion_labels = kwargs.pop("num_emotion_labels")
    else:
        try:
            num_emotion_labels = config.num_emotion_labels
        except AttributeError as err:
            raise AttributeError(
                f"Config at {pretrained_model_name_or_path!r} has no "
                "'num_emotion_labels'; pass num_emotion_labels=... explicitly."
            ) from err

    # Positional model_args directly follow the path, exactly as in the
    # original call; keywords come after them for readability.
    return CustomModel.from_pretrained(
        pretrained_model_name_or_path,
        *model_args,
        num_emotion_labels=num_emotion_labels,
        **kwargs,
    )
|