MNIST Digit Classifier
Model type: Convolutional Neural Network (CNN)
Model Architecture: 3 Convolutional Layers, 1 Adaptive Pooling Layer, 1 Fully Connected Layer
Framework: PyTorch
Task: Image Classification (Digits 0-9)
Model Description
This model is a Convolutional Neural Network (CNN) trained on the MNIST dataset of handwritten digits (0-9). It classifies images of single handwritten digits, which makes it suitable for applications that require digit recognition, such as form scanning, document analysis, and real-time digit recognition.
Model Architecture:
- Convolutional Layers: 3 convolutional layers with ReLU activations.
- Adaptive Pooling: Adaptive average pooling reduces the feature maps to a fixed spatial size, so the fully connected layer receives the same number of features regardless of the input resolution.
- Fully Connected Layer: The pooled output of the convolutional layers is flattened and fed into a fully connected layer that outputs 10 logits corresponding to the digits 0-9.
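The layout described above can be sketched as a small PyTorch module. This is only an illustration: the class name `DigitCNN`, the channel widths (16/32/64), the 3x3 kernels, and the 1x1 pooling target are assumptions, since the card does not state them, and the trained checkpoint may use different values.

```python
import torch
import torch.nn as nn


class DigitCNN(nn.Module):
    """Hypothetical layout: 3 conv layers + adaptive average pooling + 1 linear layer."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Channel widths and kernel sizes are assumptions, not the trained model's values.
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Pool every feature map down to 1x1 so the linear layer always sees
        # 64 features, whatever the input resolution is.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)      # (batch, 64, 1, 1) -> (batch, 64)
        return self.classifier(x)    # raw logits, one per digit


model = DigitCNN()
logits_28 = model(torch.randn(4, 1, 28, 28))
logits_64 = model(torch.randn(4, 1, 64, 64))
print(logits_28.shape, logits_64.shape)  # both are (4, 10)
```

Because the pooling layer fixes the feature size, the same weights accept 28x28 and 64x64 inputs. Accuracy on resolutions other than 28x28 is not something the card reports.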
Training Data
The model was trained on the MNIST dataset, which contains 60,000 training examples and 10,000 test examples of 28x28 grayscale images of handwritten digits.
Data Preprocessing:
- Data Augmentation: Random rotations (up to 10 degrees) and translations (up to 10% shift) were applied to the training data to improve generalization and robustness.
- Normalization: After conversion to the [0, 1] range, each pixel was normalized to [-1, 1] using the following parameters:
  - Mean: 0.5
  - Standard Deviation: 0.5
Intended Use
This model is suitable for:
- Recognizing handwritten digits in real-world applications such as scanned documents, forms, or digit-based input systems.
- Educational purposes to demonstrate digit classification using neural networks.
How to use:
The model can be loaded using PyTorch, and an image can be classified by following this code snippet:
```python
import torch
from torchvision import transforms
from PIL import Image

# Load the model (YourModelClass is a placeholder for the class that defines the architecture)
model = YourModelClass()
model.load_state_dict(torch.load('mnist_classifier.pth'))
model.eval()

# Preprocess image: resize to 28x28, convert to a tensor, normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

img = Image.open('path_to_image').convert('L')  # force single-channel grayscale
img_tensor = transform(img).unsqueeze(0)        # add batch dimension: (1, 1, 28, 28)

# Predict
with torch.no_grad():
    output = model(img_tensor)
    predicted_label = torch.argmax(output, dim=1).item()

print(f"Predicted Label: {predicted_label}")
```
Evaluation Results
Metrics:
The model achieved the following performance metrics on the MNIST test dataset:
- Accuracy: ~98%
- Training loss: Cross-entropy loss on the training data converged to approximately 0.15 after 10 epochs.
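For readers who want to reproduce such metrics, a minimal evaluation loop might look like the following. The function name `evaluate`, the stand-in linear model, and the random data are all hypothetical; they only keep the sketch runnable without downloading MNIST, and the numbers it prints say nothing about the real model.

```python
from typing import Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model: nn.Module, loader: DataLoader) -> Tuple[float, float]:
    """Return (accuracy, mean cross-entropy loss) over all samples in loader."""
    criterion = nn.CrossEntropyLoss(reduction="sum")  # sum, then divide by sample count
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images)
            loss_sum += criterion(logits, labels).item()
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total, loss_sum / total


# Stand-in data and model (assumptions for illustration only).
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32)
dummy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

accuracy, mean_loss = evaluate(dummy_model, loader)
print(f"accuracy={accuracy:.3f}, loss={mean_loss:.3f}")
```

With the real checkpoint, `loader` would wrap the 10,000-image MNIST test split, preprocessed with the same normalization used in training.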
Noisy Image Performance:
The model was also tested on noisy digit images. When preprocessing such as Gaussian blur and thresholding was applied before inference, it classified these digits correctly.
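A blur-then-threshold step of the kind mentioned above could be written with Pillow as follows. The helper name `denoise`, the blur radius of 1.0, and the threshold of 128 are assumptions chosen for illustration; the card does not give the settings actually used.

```python
from PIL import Image, ImageFilter


def denoise(img: Image.Image, blur_radius: float = 1.0, threshold: int = 128) -> Image.Image:
    """Smooth the image with a Gaussian blur, then binarize it at the given threshold."""
    gray = img.convert("L")
    blurred = gray.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    # Map each pixel to pure black or pure white.
    return blurred.point(lambda p: 255 if p >= threshold else 0)


# Synthetic example: a black image with one isolated bright speck.
noisy = Image.new("L", (28, 28), color=0)
noisy.putpixel((5, 5), 255)
cleaned = denoise(noisy)
print(cleaned.size, cleaned.getextrema())
```

Blurring spreads an isolated speck over its neighbours, so its peak tends to fall below the threshold, while thick pen strokes survive. The cleaned image can then go through the same resize and normalization steps as any other input.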
Limitations
- Noisy Inputs: The model may still struggle with heavily distorted or very noisy images, although preprocessing such as Gaussian blur and thresholding helps mitigate these issues.
- Generalization: The model is designed specifically for MNIST-like digits and might not generalize well to digit styles that are too different from the MNIST dataset (e.g., digits from different cultures or handwriting styles).
Training Details
Hyperparameters:
- Optimizer: Adam
- Learning Rate: 0.001
- Loss Function: Cross-entropy Loss
- Batch Size: 32
- Epochs: 10
- Data Augmentation: Random rotations and translations during training
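The hyperparameters above translate into a training loop along these lines. This is a sketch, not the author's training script: the function name `train`, the stand-in linear model, and the random tensors are hypothetical, and they exist only so the example runs on its own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train(model: nn.Module, loader: DataLoader, epochs: int = 10, lr: float = 1e-3) -> list:
    """Train with Adam and cross-entropy loss; return the mean loss of each epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    history = []
    for _ in range(epochs):
        model.train()
        running, seen = 0.0, 0
        for images, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running += loss.item() * labels.size(0)  # weight by batch size
            seen += labels.size(0)
        history.append(running / seen)
    return history


# Stand-in data and model (assumptions for illustration only).
torch.manual_seed(0)
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

losses = train(toy_model, loader, epochs=10)
print([round(value, 3) for value in losses])
```

In a real run, `loader` would wrap the MNIST training split with the augmentation transforms applied, and `toy_model` would be replaced by the CNN.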
Ethical Considerations
While this model does not have significant ethical concerns, users should be aware that it is trained on a specific dataset (MNIST) that consists only of simple, grayscale digits. It may not perform well on digits outside of this domain (e.g., digits from other scripts or more complex scenarios).
Model Card Contact
If you have any questions, feedback, or inquiries about this model, feel free to reach out to the author via [[email protected]].