File size: 611 Bytes
1c3979d
9c243c1
 
 
1c3979d
9c243c1
1c3979d
 
9c243c1
1c3979d
 
 
 
 
 
 
 
9c243c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import AutoConfig
from .modeling_custom import CustomModel

# Keyword arguments that control where/how files are fetched (Hub revision,
# auth, cache, remote-code trust). They must reach the config loader as well
# as the model loader; otherwise the config could be read from a different
# revision, or without the credentials the weights are loaded with.
_CONFIG_LOADING_KWARGS = (
    "cache_dir",
    "force_download",
    "local_files_only",
    "proxies",
    "revision",
    "subfolder",
    "token",
    "trust_remote_code",
    "use_auth_token",
)


def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
    """Load a ``CustomModel``, taking ``num_emotion_labels`` from its saved config.

    The configuration at ``pretrained_model_name_or_path`` is loaded with
    ``AutoConfig`` and its ``num_emotion_labels`` attribute is forwarded to
    ``CustomModel.from_pretrained`` along with all other arguments.

    Args:
        pretrained_model_name_or_path: Model id or local path, passed
            unchanged to both ``AutoConfig.from_pretrained`` and
            ``CustomModel.from_pretrained``.
        *model_args: Extra positional arguments forwarded to
            ``CustomModel.from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to
            ``CustomModel.from_pretrained``. The file-location options listed
            in ``_CONFIG_LOADING_KWARGS`` (e.g. ``revision``, ``token``,
            ``cache_dir``, ``trust_remote_code``) are also forwarded to
            ``AutoConfig.from_pretrained`` so config and weights come from the
            same source. An explicit ``num_emotion_labels`` takes precedence
            over the config value, and the config is then not loaded here.

    Returns:
        Whatever ``CustomModel.from_pretrained`` returns (presumably the
        loaded model instance).

    Raises:
        AttributeError: If ``num_emotion_labels`` was not passed and the loaded
            config does not define it.
    """
    if "num_emotion_labels" not in kwargs:
        # Forward only the loading-related options; model-only kwargs
        # (dtype, device placement, ...) must not be applied to the config.
        config_kwargs = {
            key: kwargs[key] for key in _CONFIG_LOADING_KWARGS if key in kwargs
        }
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **config_kwargs
        )
        try:
            # ``kwargs`` is a fresh dict built for this call, so mutating it is safe.
            kwargs["num_emotion_labels"] = config.num_emotion_labels
        except AttributeError as err:
            raise AttributeError(
                f"Config loaded from {pretrained_model_name_or_path!r} does not "
                "define 'num_emotion_labels'; pass num_emotion_labels=... explicitly."
            ) from err

    return CustomModel.from_pretrained(
        pretrained_model_name_or_path,
        *model_args,
        **kwargs,
    )