# MNIST Digit Classifier

**Model type**: Convolutional Neural Network (CNN)
**Model Architecture**: 3 Convolutional Layers, 1 Adaptive Pooling Layer, 1 Fully Connected Layer
**Framework**: PyTorch
**Task**: Image Classification (Digits 0-9)

## Model Description

This model is a Convolutional Neural Network (CNN) trained on the MNIST dataset of handwritten digits (0-9). It classifies images of handwritten digits, making it suitable for applications that require digit recognition, such as form scanning, document analysis, and real-time digit detection.

### Model Architecture:
- **Convolutional Layers**: 3 convolutional layers with ReLU activations.
- **Adaptive Pooling**: Adaptive Average Pooling gives the fully connected layer a fixed-size input regardless of the input image dimensions.
- **Fully Connected Layer**: The output of the convolutional layers is flattened and fed into a fully connected layer that produces 10 logits, one per digit (0-9). A sketch of one such architecture is shown below.
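
The card does not pin down exact layer sizes, so the following is only a minimal sketch of one architecture that fits the description above; the class name `MNISTClassifier`, channel counts, and kernel sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MNISTClassifier(nn.Module):
    """Hypothetical layout matching the card: 3 conv layers, adaptive pooling, 1 FC layer."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),   # grayscale input -> 32 feature maps
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # AdaptiveAvgPool2d produces a fixed 1x1 spatial size regardless of input resolution,
        # so the fully connected layer always receives 128 features.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)  # 10 logits, one per digit
```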

## Training Data

The model was trained on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), which contains 60,000 training examples and 10,000 test examples of 28x28 grayscale images of handwritten digits.

### Data Preprocessing:
- **Data Augmentation**: Random rotations (up to 10 degrees) and translations (up to a 10% shift) were applied to the training data to improve generalization and robustness.
- **Normalization**: Each pixel was normalized to the range [-1, 1] using the following parameters (a matching transform pipeline is sketched below):
  - Mean: 0.5
  - Standard Deviation: 0.5
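
For illustration, the augmentation and normalization above map onto a torchvision pipeline along these lines; the exact transform composition used during training is an assumption.

```python
from torchvision import transforms

# Training-time pipeline reflecting the augmentation and normalization described above
train_transform = transforms.Compose([
    transforms.RandomRotation(10),                             # rotate by up to 10 degrees
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # shift by up to 10% horizontally/vertically
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),                      # map pixel values to [-1, 1]
])

# Evaluation-time pipeline: no augmentation, only tensor conversion and normalization
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
```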

## Intended Use

This model is suitable for:
- Recognizing handwritten digits in real-world applications such as scanned documents, forms, or digit-based input systems.
- Educational purposes, to demonstrate digit classification with neural networks.

**How to use**:
The model can be loaded with PyTorch and used to classify an image as follows:

```python
import torch
from torchvision import transforms
from PIL import Image

# Load the model (replace YourModelClass with the CNN class described above)
model = YourModelClass()
model.load_state_dict(torch.load('mnist_classifier.pth'))
model.eval()

# Preprocess the image: resize to 28x28, convert to a tensor, normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the image as grayscale and add a batch dimension
img = Image.open('path_to_image').convert('L')
img_tensor = transform(img).unsqueeze(0)

# Predict the digit
with torch.no_grad():
    output = model(img_tensor)
    predicted_label = torch.argmax(output, dim=1).item()

print(f"Predicted Label: {predicted_label}")
```

## Evaluation Results

### Metrics:
The model achieved the following performance on the MNIST test set:
- **Accuracy**: ~98%
- **Loss**: Training cross-entropy loss converged to roughly 0.15 after 10 epochs.
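
For reference, a test-set accuracy of this kind is typically measured with a loop like the one below; this is a sketch that assumes a trained `model` instance (e.g., from the usage snippet above) and a local `data` directory, not the author's exact evaluation script.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Evaluation data: plain tensor conversion and normalization, no augmentation
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
test_set = datasets.MNIST(root='data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_set, batch_size=32)

model.eval()
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images)
        correct += (logits.argmax(dim=1) == labels).sum().item()

accuracy = correct / len(test_set)
print(f"Test accuracy: {accuracy:.4f}")
```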

### Noisy Image Performance:
The model was also tested on noisy digit images and classified them correctly once preprocessing (e.g., Gaussian blur and thresholding) was applied.
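
The card does not specify the exact cleanup used in that test; a minimal sketch of one such preprocessing step, using PIL with a hypothetical `denoise_digit` helper and an assumed global threshold of 128, could look like this:

```python
from PIL import Image, ImageFilter

def denoise_digit(path: str, threshold: int = 128) -> Image.Image:
    """Illustrative cleanup for a noisy digit image: blur, then binarize."""
    img = Image.open(path).convert('L')                       # load as grayscale
    img = img.filter(ImageFilter.GaussianBlur(radius=1))      # smooth out speckle noise
    return img.point(lambda p: 255 if p > threshold else 0)   # simple global threshold

# The cleaned image can then be passed through the transform pipeline from the usage snippet above.
clean = denoise_digit('noisy_digit.png')
```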

## Limitations

- **Noisy Inputs**: The model may still struggle with heavily noisy or distorted images, although preprocessing techniques such as Gaussian blur and normalization help mitigate this.
- **Generalization**: The model is designed specifically for MNIST-like digits and may not generalize well to digit styles that differ substantially from the MNIST dataset (e.g., digits from other scripts or handwriting styles).

## Training Details

### Hyperparameters:
- **Optimizer**: Adam
  - Learning Rate: 0.001
- **Loss Function**: Cross-entropy Loss
- **Batch Size**: 32
- **Epochs**: 10
- **Data Augmentation**: Random rotations and translations during training (a training-loop sketch using these settings follows this list)
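
A minimal training loop consistent with these hyperparameters is sketched below; it assumes the `MNISTClassifier` and `train_transform` sketches from earlier sections and is not the author's exact training script.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets

# Hypothetical training setup reflecting the hyperparameters listed above
train_set = datasets.MNIST(root='data', train=True, download=True, transform=train_transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

model = MNISTClassifier()                                     # sketch class from the architecture section
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}: average loss = {running_loss / len(train_loader):.4f}")
```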

## Ethical Considerations

While this model does not raise significant ethical concerns, users should be aware that it is trained on a single dataset (MNIST) consisting only of simple grayscale digits. It may not perform well on digits outside this domain (e.g., digits from other scripts or more complex scenarios).

## Model Card Contact

If you have any questions, feedback, or inquiries about this model, feel free to reach out to the author via [[email protected]].