harry commited on
Commit
0dee387
·
1 Parent(s): aaea685

feat: add prediction script and model file for MNIST digit classification

Browse files
models/mnist_model_lr0.001_bs64_ep20.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aead1b3223333f05acf8494c6a73aec8bdaa9e32d3f0c239b16e5e12a3a07a8f
3
+ size 4803144
predict.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ # mypy: ignore-errors
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from mnist_classifier.model import MNISTModel
6
+ import torch.nn.functional as F
7
+
8
+ def load_model(model_path):
9
+ """Load the trained model."""
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+ model = MNISTModel().to(device)
12
+ model.load_state_dict(torch.load(model_path, weights_only=True))
13
+ model.eval()
14
+ return model, device
15
+
16
+ def preprocess_image(image_path):
17
+ """Preprocess the input image."""
18
+ transform = transforms.Compose([
19
+ transforms.Grayscale(num_output_channels=1),
20
+ transforms.Resize((28, 28)),
21
+ transforms.ToTensor(), # This converts PIL Image to tensor
22
+ transforms.Normalize((0.5,), (0.5,))
23
+ ])
24
+
25
+ image = Image.open(image_path)
26
+ image_tensor = transform(image) # Now image_tensor is already a tensor
27
+ return image_tensor.unsqueeze(0) # type: ignore # Add batch dimension using tensor method
28
+
29
+ def predict(model, image, device):
30
+ """Make prediction on the input image."""
31
+ with torch.no_grad():
32
+ image = image.to(device)
33
+ output = model(image)
34
+ probabilities = F.softmax(output, dim=1)
35
+ pred = output.argmax(dim=1, keepdim=True)
36
+ return pred.item(), probabilities[0]
37
+
38
+ def main():
39
+ # Path to your trained model
40
+ model_path = "./models/mnist_model_lr0.001_bs64_ep10.pth"
41
+
42
+ # Load model
43
+ model, device = load_model(model_path)
44
+
45
+ # Path to input image
46
+ image_path = "./test/image.jpg"
47
+
48
+ # Preprocess image and get prediction
49
+ image = preprocess_image(image_path)
50
+ prediction, probabilities = predict(model, image, device)
51
+
52
+ print(f"Predicted digit: {prediction}")
53
+ print("\nProbabilities for each digit:")
54
+ for digit, prob in enumerate(probabilities):
55
+ print(f"{digit}: {prob.item():.4f}")
56
+
57
+ if __name__ == "__main__":
58
+ main()
test/image.jpg ADDED
torchvision.pyi CHANGED
@@ -6,4 +6,7 @@ class datasets:
6
  class transforms:
7
  Compose: Any
8
  ToTensor: Any
9
- Normalize: Any
 
 
 
 
6
  class transforms:
7
  Compose: Any
8
  ToTensor: Any
9
+ Normalize: Any
10
+ Grayscale: Any
11
+ Resize: Any
12
+