MNIST Digit Classifier with Noise Reduction
Model type: Convolutional Neural Network (CNN)
Model Architecture: 3 Convolutional Layers, 1 Adaptive Pooling Layer, 1 Fully Connected Layer
Framework: PyTorch
Task: Image Classification (Digits 0-9)
Model Description
This model is a convolutional neural network (CNN) trained on the MNIST dataset to classify handwritten digits from 0 to 9. Data augmentation during training improves its robustness, especially to noisy or rotated images, and a Gaussian blur in the preprocessing step reduces noise, making the model more resilient to outliers and noisy digit inputs.
Model Architecture:
- Convolutional Layers: 3 convolutional layers with ReLU activations.
- Adaptive Pooling: Adaptive average pooling reduces each feature map to a fixed spatial size regardless of the input resolution, so the model can handle inputs of varying size.
- Fully Connected Layer: The output from the convolutional layers is flattened and passed through a fully connected layer to predict the digit.
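The architecture above can be sketched as follows. This is a hypothetical reconstruction: the card gives only the layer types, so the channel widths (32, 64, 64), the 3×3 kernels, and the 1×1 pooled output size are assumptions. A checkpoint such as `mnist_classifier.pth` will only load into a class whose layer names and shapes exactly match the one used in training.

```python
import torch
import torch.nn as nn


class ImageClassifier(nn.Module):
    """Hypothetical sketch: 3 conv layers with ReLU, adaptive average pooling,
    and one fully connected layer. Widths and kernel sizes are assumptions."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),   # assumed width: 32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # assumed width: 64
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # assumed width: 64
            nn.ReLU(),
        )
        # Collapse every feature map to 1x1 so the classifier input size
        # no longer depends on the input image resolution.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)       # (N, 1, H, W) -> (N, 64, H, W)
        x = self.pool(x)           # (N, 64, H, W) -> (N, 64, 1, 1)
        x = torch.flatten(x, 1)    # (N, 64, 1, 1) -> (N, 64)
        return self.fc(x)          # raw logits; CrossEntropyLoss applies log-softmax
```

Because adaptive pooling collapses each feature map to 1×1, a 40×40 input yields logits of the same shape, (1, 10), as a 28×28 input.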
Training Data
The model was trained on the MNIST dataset, which consists of 60,000 training images and 10,000 test images of handwritten digits. The images are 28x28 pixels in grayscale.
Data Preprocessing:
- Data Augmentation: Random rotations (up to 10 degrees) and random translations (up to 10%) were applied during training to make the model more robust to variations.
- Normalization: Pixel values were normalized to the range [-1, 1] using a mean of 0.5 and a standard deviation of 0.5.
Intended Use
This model is designed for:
- Recognizing handwritten digits in applications like form scanning, document analysis, or real-time digit detection.
- Educational purposes to demonstrate CNN-based image classification.
How to use the model
You can load the trained model in PyTorch and use it to classify digit images as shown below:
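```python
import torch
from torchvision import transforms
from PIL import Image

# Load the model. ImageClassifier must be the same class definition used
# during training so that the saved state_dict keys and shapes match.
model = ImageClassifier()
model.load_state_dict(torch.load('mnist_classifier.pth', map_location='cpu'))
model.eval()  # switch to inference mode

# Preprocess an input image to match the training normalization
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),                # [0, 255] -> [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,))  # [0.0, 1.0] -> [-1.0, 1.0]
])

# MNIST digits are light strokes on a dark background; invert dark-on-light
# images first (e.g. with PIL.ImageOps.invert) for best results.
img = Image.open('path_to_image').convert('L')  # force single-channel grayscale
img_tensor = transform(img).unsqueeze(0)         # add batch dim: (1, 1, 28, 28)

# Perform inference without tracking gradients
with torch.no_grad():
    output = model(img_tensor)
    predicted_label = torch.argmax(output, dim=1).item()
print(f"Predicted Label: {predicted_label}")
```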
Evaluation Results
Metrics:
The model was evaluated on the MNIST test set, achieving the following results:
- Accuracy: ~98%
- Loss: Cross-entropy loss reached approximately 0.15 after 10 epochs.
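Metrics like these can be computed with an evaluation loop such as the one below. It is a sketch rather than the author's evaluation script: it assumes a `DataLoader` over the MNIST test set that yields `(images, labels)` batches, and `evaluate` is a hypothetical helper that reports accuracy together with the mean per-sample cross-entropy.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cpu"):
    """Return (accuracy, mean cross-entropy) over every sample in `loader`."""
    model.eval()
    # Sum per-sample losses so uneven final batches are weighted correctly.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss, correct, count = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        count += labels.size(0)
    return correct / count, total_loss / count
```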
Performance on Noisy Inputs:
The model was tested on noisy images (e.g., images with added noise or distortions), and the preprocessing steps (Gaussian blur, resizing) helped improve the model’s performance on such inputs.
Limitations
- Noisy Inputs: Although preprocessing helps, very noisy or distorted inputs might still be challenging for the model.
- Generalization: This model is primarily trained on MNIST digits, so it may not generalize well to other handwriting styles or to digits from other numeral systems.
Training Details
Hyperparameters:
- Optimizer: Adam
- Learning Rate: 0.001
- Loss Function: Cross-entropy Loss
- Batch Size: 32
- Epochs: 10
- Data Augmentation: Random rotations and translations
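A training loop using the hyperparameters above might look like this. It is a sketch, not the author's script: `train` is a hypothetical helper, the batch size of 32 is set on the `DataLoader` rather than inside the loop, and the returned per-epoch losses are averaged over training samples.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(model: nn.Module, loader: DataLoader, epochs: int = 10,
          lr: float = 1e-3, device: str = "cpu") -> list[float]:
    """Train with Adam and cross-entropy; return the mean loss of each epoch."""
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    history = []
    for epoch in range(epochs):
        model.train()
        running, seen = 0.0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is per-sample.
            running += loss.item() * labels.size(0)
            seen += labels.size(0)
        history.append(running / seen)
        print(f"epoch {epoch + 1}/{epochs}: loss {history[-1]:.4f}")
    return history
```

With the MNIST training set wrapped in `DataLoader(train_set, batch_size=32, shuffle=True)` (shuffling is an assumption), calling `train(model, loader)` runs the 10-epoch, learning-rate 0.001 configuration listed above.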
Ethical Considerations
There are no significant ethical concerns related to this model. However, users should be aware that the model is specifically trained on simple MNIST digits and may not perform well in more complex scenarios.
Contact
For any questions or feedback, please reach out to the model author via [[email protected]].