jerilseb commited on
Commit
c55f75b
·
verified ·
1 Parent(s): 0d98f49

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +40 -5
README.md CHANGED
@@ -5,15 +5,50 @@ license: mit
5
  ## Usage
6
 
7
  ```python
8
- state_dict = torch.load('model.pth', map_location='cpu')
9
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  model.eval()
11
 
12
- def predict(im):
13
- x = torch.tensor(im, dtype=torch.float32)
 
 
 
 
 
14
 
 
 
 
15
  with torch.no_grad():
16
- out = model(x)
17
 
18
  probabilities = F.softmax(out[0], dim=0)
 
 
19
  ```
 
5
  ## Usage
6
 
7
  ```python
8
+ import torch
9
+ from torch import nn
10
+ import torchvision.transforms as transforms
11
+ import torch.nn.functional as F
12
+ from pathlib import Path
13
+
14
+ LABELS = Path("classes.txt").read_text().splitlines()
15
+ num_classes = len(LABELS)
16
+
17
+ model = nn.Sequential(
18
+ nn.Conv2d(1, 64, 3, padding="same"),
19
+ nn.ReLU(),
20
+ nn.MaxPool2d(2),
21
+ nn.Conv2d(64, 128, 3, padding="same"),
22
+ nn.ReLU(),
23
+ nn.MaxPool2d(2),
24
+ nn.Conv2d(128, 256, 3, padding="same"),
25
+ nn.ReLU(),
26
+ nn.MaxPool2d(2),
27
+ nn.Flatten(),
28
+ nn.Linear(2304, 512),
29
+ nn.ReLU(),
30
+ nn.Linear(512, num_classes),
31
+ )
32
+
33
+ state_dict = torch.load("model.pth", map_location="cpu")
34
+ model.load_state_dict(state_dict)
35
  model.eval()
36
 
37
+ transform = transforms.Compose(
38
+ [
39
+ transforms.Resize((28, 28)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize((0.5,), (0.5,)),
42
+ ]
43
+ )
44
 
45
+ def predict(image):
46
+ image = image['composite']
47
+ tensor = transform(image).unsqueeze(0)
48
  with torch.no_grad():
49
+ out = model(tensor)
50
 
51
  probabilities = F.softmax(out[0], dim=0)
52
+ values, indices = torch.topk(probabilities, 5)
53
+ print(values, indices)
54
  ```