gsar78 commited on
Commit
b4c4b45
1 Parent(s): 1c50ee3

Update custom_model_package/custom_model.py

Browse files
custom_model_package/custom_model.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers import PretrainedConfig, XLMRobertaForSequenceClassification
2
  import torch.nn as nn
3
  import torch
 
4
 
5
  class CustomConfig(PretrainedConfig):
6
  model_type = "custom_model"
@@ -9,7 +10,7 @@ class CustomConfig(PretrainedConfig):
9
  super().__init__(**kwargs)
10
  self.num_emotion_labels = num_emotion_labels
11
 
12
- class CustomModel(PreTrainedModel):
13
  config_class = CustomConfig
14
 
15
  def __init__(self, config, num_emotion_labels):
 
1
  from transformers import PretrainedConfig, XLMRobertaForSequenceClassification
2
  import torch.nn as nn
3
  import torch
4
+ from huggingface_hub import PyTorchModelHubMixin
5
 
6
  class CustomConfig(PretrainedConfig):
7
  model_type = "custom_model"
 
10
  super().__init__(**kwargs)
11
  self.num_emotion_labels = num_emotion_labels
12
 
13
+ class CustomModel(nn.Module, PyTorchModelHubMixin):
14
  config_class = CustomConfig
15
 
16
  def __init__(self, config, num_emotion_labels):