jerilseb commited on
Commit
482076b
·
verified ·
1 Parent(s): 58ed48f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +37 -1
README.md CHANGED
@@ -5,6 +5,41 @@ tags: []
5
 
6
  # Usage
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  Register the model
9
 
10
  ```python
@@ -13,7 +48,8 @@ from transformers import AutoConfig, AutoModel
13
  AutoConfig.register("mnist_classifier", MNISTConfig)
14
  AutoModel.register(MNISTConfig, MNISTClassifier)
15
  ```
16
- Inference
 
17
  ```python
18
  from transformers import AutoConfig, AutoModel
19
  import torch
 
5
 
6
  # Usage
7
 
8
+ Define the model and config
9
+ ```python
10
+ from transformers import PreTrainedModel, PretrainedConfig
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ class MNISTConfig(PretrainedConfig):
15
+ model_type = "mnist_classifier"
16
+
17
+ def __init__(self, input_size=784, hidden_size1=1024, hidden_size2=512, num_labels=10, **kwargs):
18
+ super().__init__(**kwargs)
19
+ self.input_size = input_size
20
+ self.hidden_size1 = hidden_size1
21
+ self.hidden_size2 = hidden_size2
22
+ self.num_labels = num_labels
23
+
24
+ class MNISTClassifier(PreTrainedModel):
25
+ config_class = MNISTConfig
26
+
27
+ def __init__(self, config):
28
+ super().__init__(config)
29
+ self.layer1 = nn.Linear(config.input_size, config.hidden_size1)
30
+ self.layer2 = nn.Linear(config.hidden_size1, config.hidden_size2)
31
+ self.layer3 = nn.Linear(config.hidden_size2, config.num_labels)
32
+
33
+ def forward(self, pixel_values):
34
+ inputs = pixel_values.view(-1, self.config.input_size)
35
+ outputs = self.layer1(inputs)
36
+ outputs = F.leaky_relu(outputs)
37
+ outputs = self.layer2(outputs)
38
+ outputs = F.leaky_relu(outputs)
39
+ outputs = self.layer3(outputs)
40
+ return outputs
41
+ ```
42
+
43
  Register the model
44
 
45
  ```python
 
48
  AutoConfig.register("mnist_classifier", MNISTConfig)
49
  AutoModel.register(MNISTConfig, MNISTClassifier)
50
  ```
51
+
52
+ Run Inference
53
  ```python
54
  from transformers import AutoConfig, AutoModel
55
  import torch