gsar78 committed on
Commit
1c3979d
1 Parent(s): f0e224d

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +12 -3
modeling.py CHANGED
@@ -1,9 +1,18 @@
1
- # modeling.py
2
- from transformers import XLMRobertaForSequenceClassification, AutoConfig
3
  from .modeling_custom import CustomModel
4
 
5
  def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
 
6
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
 
 
7
  num_emotion_labels = config.num_emotion_labels
8
- model = CustomModel.from_pretrained(pretrained_model_name_or_path, num_emotion_labels=num_emotion_labels, *model_args, **kwargs)
 
 
 
 
 
 
 
9
  return model
 
1
+ from transformers import AutoConfig
 
2
  from .modeling_custom import CustomModel
3
 
4
def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
    """Load a ``CustomModel``, forwarding the emotion-label count from its config.

    Mirrors the ``transformers`` ``from_pretrained`` convention:

    Args:
        pretrained_model_name_or_path: Hub model id or local checkpoint path.
        *model_args: Extra positional arguments passed through to
            ``CustomModel.from_pretrained``.
        **kwargs: Extra keyword arguments passed through to
            ``CustomModel.from_pretrained``.

    Returns:
        The instantiated ``CustomModel``.
    """
    # Load the checkpoint's configuration to discover the label count.
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

    # NOTE(review): ``num_emotion_labels`` is a custom config field — this
    # assumes the checkpoint's config.json defines it; an AttributeError is
    # raised otherwise.
    num_emotion_labels = config.num_emotion_labels

    # Fix: ``*model_args`` must come before keyword arguments in the call.
    # The original wrote ``(path, num_emotion_labels=..., *model_args, **kwargs)``,
    # which is legal but confusing and can raise TypeError if a positional
    # argument collides with the keyword. The reordered call is otherwise
    # identical in effect.
    model = CustomModel.from_pretrained(
        pretrained_model_name_or_path,
        *model_args,
        num_emotion_labels=num_emotion_labels,
        **kwargs,
    )
    return model