# MNIST Digit Classifier

**Model type**: Convolutional Neural Network (CNN)

**Model Architecture**: 3 Convolutional Layers, 1 Adaptive Pooling Layer, 1 Fully Connected Layer

**Framework**: PyTorch

**Task**: Image Classification (Digits 0-9)

## Model Description

This model is a Convolutional Neural Network (CNN) trained on the MNIST dataset of handwritten digits (0-9). It classifies images of single handwritten digits, which makes it suitable for applications that require digit recognition, such as form scanning, document analysis, and real-time handwritten digit input.

### Model Architecture:

- **Convolutional Layers**: 3 convolutional layers, each followed by a ReLU activation.
- **Adaptive Pooling**: An adaptive average pooling layer reduces the final feature maps to a fixed spatial size regardless of the input resolution, so the fully connected layer always receives a vector of the same length.
- **Fully Connected Layer**: The pooled feature maps are flattened and fed into a fully connected layer that outputs 10 logits, one per digit (0-9).
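
The card does not list channel widths or kernel sizes, so the following is only a minimal sketch of a network with the layout described above. The class name `MNISTCNN`, the 32/64/128 channel widths, the 3x3 kernels with padding 1, and the 1x1 pooling output are all assumptions. A checkpoint such as `mnist_classifier.pth` will only load into this class if those choices happen to match the trained model.

```python
import torch
import torch.nn as nn


class MNISTCNN(nn.Module):
    """Hypothetical CNN with the described layout: 3 conv layers with ReLU,
    adaptive average pooling, and one fully connected layer producing 10 logits.
    Channel widths and kernel sizes are assumptions, not the trained model's values."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Collapses any HxW feature map to 1x1, so the classifier input size is fixed.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)  # (N, 128, 1, 1) -> (N, 128)
        return self.fc(x)  # raw logits, one per digit class


if __name__ == "__main__":
    model = MNISTCNN()
    print(model(torch.randn(4, 1, 28, 28)).shape)  # torch.Size([4, 10])
    print(model(torch.randn(1, 1, 40, 40)).shape)  # torch.Size([1, 10])
```

Because of the adaptive pooling layer, the same module accepts a 40x40 input as well as a 28x28 one. Without it, the input size of `nn.Linear` would depend on the image resolution.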

## Training Data

The model was trained on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), which contains 60,000 training examples and 10,000 test examples of 28x28 grayscale images of handwritten digits.

### Data Preprocessing:

- **Data Augmentation**: Random rotations (up to 10 degrees) and translations (up to 10% shift) were applied to the training data to improve generalization and robustness.
- **Normalization**: After conversion to tensors (pixel values in [0, 1]), each pixel was mapped to the range [-1, 1] as (x - mean) / std, using:
  - Mean: 0.5
  - Standard Deviation: 0.5

## Intended Use

This model is suitable for:

- Recognizing handwritten digits in real-world applications such as scanned documents, forms, or digit-based input systems.
- Educational purposes, to demonstrate digit classification with neural networks.

**How to use**:

The model can be loaded with PyTorch, and an image can be classified with the following code snippet:
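
```python
import torch
from torchvision import transforms
from PIL import Image

# Load the model.
# YourModelClass is a placeholder: it must be the same nn.Module definition
# (layers and sizes) that was used to train and save mnist_classifier.pth.
model = YourModelClass()
model.load_state_dict(torch.load('mnist_classifier.pth', map_location='cpu'))
model.eval()  # switch to inference mode

# Preprocess the image the same way as the training data
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),                # [0, 255] -> [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,))  # [0.0, 1.0] -> [-1.0, 1.0]
])

# MNIST digits are light strokes on a dark background; dark-on-light scans
# may need to be inverted first (e.g. with PIL.ImageOps.invert).
img = Image.open('path_to_image').convert('L')  # force single-channel grayscale
img_tensor = transform(img).unsqueeze(0)        # add batch dimension: (1, 1, 28, 28)

# Predict
with torch.no_grad():
    output = model(img_tensor)                  # logits, shape (1, 10)
    predicted_label = torch.argmax(output, dim=1).item()

print(f"Predicted Label: {predicted_label}")
```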
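
The snippet above reports only the arg-max label. If a confidence score is also useful, the 10 logits can be turned into probabilities with a softmax. The helper `predict_with_confidence` below is a hypothetical addition for illustration, not part of the released model.

```python
import torch


def predict_with_confidence(logits: torch.Tensor) -> tuple[int, float]:
    """Return (label, probability) from logits of shape (1, 10) for one image."""
    probs = torch.softmax(logits, dim=1)  # probabilities over the 10 digits, summing to 1
    confidence, label = probs.max(dim=1)  # highest probability and its class index
    return int(label.item()), float(confidence.item())


# Hand-made logits that strongly favor digit 2; in practice pass the model's output.
logits = torch.tensor([[0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
label, confidence = predict_with_confidence(logits)
print(label, round(confidence, 3))  # 2 0.943
```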

## Evaluation Results

### Metrics:

The model achieved the following results:

- **Accuracy** (MNIST test set): ~98%
- **Loss**: The training cross-entropy loss converged to ~0.15 after 10 epochs.
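
The card does not include its evaluation script. A minimal sketch of how test accuracy and mean cross-entropy loss could be computed for a trained model follows. The `evaluate` function is hypothetical, and the demo uses random tensors and a toy linear model in place of the MNIST test set and the trained CNN, so its printed numbers carry no meaning.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> tuple[float, float]:
    """Return (accuracy, mean cross-entropy loss) of `model` over `loader`."""
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction="sum")  # sum per batch, average at the end
    correct, total, loss_sum = 0, 0, 0.0
    for images, labels in loader:
        logits = model(images)
        loss_sum += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / total, loss_sum / total


if __name__ == "__main__":
    # Synthetic stand-in for the MNIST test set; a real run would use something like
    # torchvision.datasets.MNIST(root, train=False, transform=...).
    data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    acc, loss = evaluate(toy_model, DataLoader(data, batch_size=32))
    print(f"accuracy={acc:.3f} loss={loss:.3f}")
```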

### Noisy Image Performance:

The model was also tested on noisy digit images, which it classified successfully when preprocessing (e.g., Gaussian blur and thresholding) was applied before inference.
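
The exact noise-handling pipeline is not specified. One way to apply Gaussian blur followed by thresholding with Pillow is sketched below. The function name `denoise_and_threshold`, the blur radius of 1.0, and the threshold of 128 are illustrative assumptions. The result is still a grayscale PIL image, so it can be passed to the inference transform as usual.

```python
import numpy as np
from PIL import Image, ImageFilter


def denoise_and_threshold(img: Image.Image, radius: float = 1.0,
                          threshold: int = 128) -> Image.Image:
    """Blur a grayscale image to suppress speckle noise, then binarize it.

    `radius` and `threshold` are illustrative defaults, not values from the card.
    """
    gray = img.convert("L")
    blurred = gray.filter(ImageFilter.GaussianBlur(radius=radius))
    # Pixels above the threshold become 255 (stroke), all others 0 (background).
    return blurred.point(lambda p: 255 if p > threshold else 0)


if __name__ == "__main__":
    # Synthetic example: a bright vertical bar (a crude "1") plus random speckle noise.
    rng = np.random.default_rng(0)
    arr = np.zeros((28, 28), dtype=np.uint8)
    arr[6:22, 12:16] = 255
    noise = rng.integers(0, 2, size=(28, 28)) * rng.integers(0, 90, size=(28, 28))
    noisy = np.clip(arr.astype(int) + noise, 0, 255).astype(np.uint8)
    clean = denoise_and_threshold(Image.fromarray(noisy))
    print(clean.mode, clean.size, np.unique(np.asarray(clean)).tolist())
```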

## Limitations

- **Noisy Inputs**: The model may still struggle with heavily noisy or distorted images, although preprocessing techniques such as Gaussian blur and normalization help mitigate these issues.
- **Generalization**: The model is designed specifically for MNIST-like digits and may not generalize well to digit styles that differ substantially from the MNIST dataset (e.g., digits from different cultures or unusual handwriting styles).

## Training Details

### Hyperparameters:

- **Optimizer**: Adam
  - Learning Rate: 0.001
- **Loss Function**: Cross-entropy Loss
- **Batch Size**: 32
- **Epochs**: 10
- **Data Augmentation**: Random rotations and translations during training
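
The training script itself is not published. The loop below is a minimal sketch using the listed hyperparameters (Adam with learning rate 0.001, cross-entropy loss, batch size 32, 10 epochs). The `train` function is hypothetical. `dataset` is assumed to yield `(image, label)` pairs, for example `torchvision.datasets.MNIST(root, train=True, transform=...)` with the augmentation described above; the demo uses random tensors and a toy model instead.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


def train(model: nn.Module, dataset: Dataset, epochs: int = 10,
          batch_size: int = 32, lr: float = 1e-3) -> list[float]:
    """Train `model` on `dataset` and return the mean training loss of each epoch."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
    history = []
    for epoch in range(epochs):
        model.train()
        total_loss, seen = 0.0, 0
        for images, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch is averaged correctly.
            total_loss += loss.item() * labels.size(0)
            seen += labels.size(0)
        history.append(total_loss / seen)
        print(f"epoch {epoch + 1}/{epochs}: loss={history[-1]:.4f}")
    return history


if __name__ == "__main__":
    # Random stand-in data and a toy model, only to show the call; not MNIST results.
    data = TensorDataset(torch.randn(96, 1, 28, 28), torch.randint(0, 10, (96,)))
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    train(toy_model, data, epochs=2)
```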

## Ethical Considerations

While this model does not raise significant ethical concerns, users should be aware that it was trained on a single dataset (MNIST) that consists only of simple, grayscale digits. It may not perform well on digits outside this domain (e.g., digits from other scripts or more complex scenes).

## Model Card Contact

If you have any questions, feedback, or inquiries about this model, feel free to reach out to the author via [[email protected]].