The MNIST OCR (Optical Character Recognition) model is a deep learning model trained to recognise and classify handwritten digits from 0 to 9. This model is trained on the MNIST dataset, which consists of 60,000 small square 28×28 pixel grayscale images of handwritten single digits, making it highly accurate for recognising written, isolated digits in a similar style to those found in the training set.
Install Packages
pip install numpy opencv-python requests pillow transformers tensorflow
Usage
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np
import cv2
import requests
from PIL import Image
from io import BytesIO
from typing import List, Optional
from huggingface_hub import hf_hub_download
import tensorflow as tf
import pickle
class ImageTokenizer:
    """Map raw pixel values to integer token ids and back.

    The vocabulary is the set of distinct pixel values seen by ``fit``;
    a value's token id is its position in the sorted vocabulary.

    NOTE(review): instances of this class are presumably what the pickled
    tokenizer file contains, so the class name and the three attribute names
    below should stay stable -- confirm before renaming anything.
    """

    def __init__(self):
        # Distinct pixel values accumulated over every fitted image.
        self.unique_pixels = set()
        # Forward (pixel value -> token id) and reverse (token id -> pixel
        # value) lookup tables; both are rebuilt by ``fit``.
        self.pixel_to_token = {}
        self.token_to_pixel = {}

    def fit(self, images):
        """Add the pixel values found in ``images`` to the vocabulary.

        Args:
            images: Iterable of array-likes accepted by ``np.unique``.

        Calling ``fit`` again extends the existing vocabulary, which can
        renumber previously assigned token ids because ids follow the sorted
        order of all values seen so far.
        """
        for image in images:
            self.unique_pixels.update(np.unique(image))

        # The tables depend only on the final vocabulary, so build them once
        # after all images have been scanned.
        self.pixel_to_token = {
            pixel: token for token, pixel in enumerate(sorted(self.unique_pixels))
        }
        self.token_to_pixel = {
            token: pixel for pixel, token in self.pixel_to_token.items()
        }

    def tokenize(self, images):
        """Return an array shaped like ``images`` holding token ids.

        Values that were never seen by ``fit`` are looked up with
        ``dict.get`` and therefore come back as ``None``.
        """
        return np.vectorize(self.pixel_to_token.get)(images)

    def detokenize(self, tokens):
        """Return an array shaped like ``tokens`` holding pixel values.

        Token ids that are not in the vocabulary come back as ``None``.
        """
        return np.vectorize(self.token_to_pixel.get)(tokens)
class MNISTPredictor:
    """Segment handwritten digits out of an image and classify each one.

    The Keras model and the pickled tokenizer are both downloaded from the
    Hugging Face Hub repository given to the constructor.
    """

    # Contours whose area is not larger than this are treated as noise and
    # skipped. Adjust as needed for the input resolution.
    MIN_CONTOUR_AREA = 50

    def __init__(self, model_name):
        """Download and load the model and tokenizer.

        Args:
            model_name: Hub repository id that contains ``mnist_model.keras``
                and ``mnist_tokenizer.pkl``.
        """
        # Download the model and tokenizer files
        model_path = hf_hub_download(repo_id=model_name, filename="mnist_model.keras")
        tokenizer_path = hf_hub_download(repo_id=model_name, filename="mnist_tokenizer.pkl")

        # Load the model and tokenizer
        self.model = keras.models.load_model(model_path)
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point ``model_name`` at repositories you trust.
        with open(tokenizer_path, 'rb') as tokenizer_file:
            self.tokenizer = pickle.load(tokenizer_file)

    def extract_features(self, image: Image.Image) -> List[np.ndarray]:
        """Extract one 28x28 crop per detected digit.

        Args:
            image: PIL image in any mode; it is converted to RGB first.

        Returns:
            A list of ``float32`` arrays of shape ``(28, 28, 1)`` scaled to
            ``[0, 1]``, in the order ``cv2.findContours`` returned the
            contours (not necessarily reading order). The list is empty when
            no contour exceeds ``MIN_CONTOUR_AREA``.
        """
        # cv2.COLOR_RGB2GRAY needs exactly three channels, but downloaded
        # images are frequently RGBA, palette or already grayscale.
        rgb = np.array(image.convert("RGB"))

        # Convert to grayscale
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        # Apply Gaussian blur
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        # Apply adaptive thresholding; THRESH_BINARY_INV inverts the binary
        # output relative to THRESH_BINARY.
        thresh = cv2.adaptiveThreshold(
            blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
        )
        # Find contours (outermost only)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        digit_images = []
        for contour in contours:
            # Filter small contours
            if cv2.contourArea(contour) <= self.MIN_CONTOUR_AREA:
                continue
            x, y, w, h = cv2.boundingRect(contour)
            roi = thresh[y:y+h, x:x+w]
            # The bounding box is stretched to 28x28 without preserving its
            # aspect ratio.
            resized = cv2.resize(roi, (28, 28), interpolation=cv2.INTER_AREA)
            digit_images.append(resized.reshape((28, 28, 1)).astype('float32') / 255)
        return digit_images

    def predict(self, image: Image.Image) -> Optional[List[int]]:
        """Predict digits in the image.

        Returns:
            One predicted class index per detected digit (same order as
            ``extract_features``), an empty list when no digit was detected,
            or ``None`` when any step raised; the error is printed.
        """
        try:
            digit_images = self.extract_features(image)
            if not digit_images:
                # Nothing to classify; avoid sending an empty batch through
                # the model.
                return []
            tokenized_images = [self.tokenizer.tokenize(img) for img in digit_images]
            predictions = self.model.predict(np.array(tokenized_images), verbose=0)
            return np.argmax(predictions, axis=1).tolist()
        except Exception as e:
            # Best-effort boundary: callers receive None instead of an exception.
            print(f"Error during prediction: {e}")
            return None
def download_image(url: str, timeout: float = 30.0) -> Optional[Image.Image]:
    """Download an image from a URL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.

    Returns:
        The decoded PIL image, or ``None`` when the request or the decoding
        failed; the error is printed.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; force decoding now so corrupt data is reported
        # here rather than later during prediction.
        image.load()
        return image
    except Exception as e:
        # Best-effort boundary: callers check for None.
        print(f"Error downloading image: {e}")
        return None
def save_predictions_to_file(predictions: List[int], output_path: str) -> None:
    """Save predictions to a text file.

    Writes a single line of the form ``Predicted digits are: 1, 2, 3``,
    overwriting any existing file at ``output_path``.

    Args:
        predictions: Values to write; each is converted with ``str``.
        output_path: Destination file path.

    This function never raises: any failure is printed instead.
    """
    try:
        # Explicit encoding keeps the output independent of the platform's
        # default locale.
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(f"Predicted digits are: {', '.join(map(str, predictions))}\n")
    except Exception as e:
        # Deliberately best-effort: report the problem and carry on.
        print(f"Error saving predictions to file: {e}")
def main(image_url: str, model_name: str, output_path: str) -> None:
    """Download an image, predict its digits and write them to a file.

    Args:
        image_url: URL of the image to analyse.
        model_name: Hugging Face Hub repository id passed to ``MNISTPredictor``.
        output_path: File that receives the predictions.

    Progress and failures are reported with ``print``; nothing is raised.
    """
    try:
        predictor = MNISTPredictor(model_name)

        # Download image
        image = download_image(image_url)
        if image is None:
            # Same message as before, without raising an exception only to
            # catch it a few lines further down.
            print("An error occurred: Failed to download image")
            return
        print("Image downloaded successfully.")

        # Predict digits
        digits = predictor.predict(image)
        if digits is None:
            print("Failed to predict digits.")
            return
        print(f"Predicted digits are: {digits}")

        # Save predictions to file
        save_predictions_to_file(digits, output_path)
        print(f"Predictions saved to {output_path}")
    except Exception as e:
        # Top-level boundary of the script: report and exit normally.
        print(f"An error occurred: {e}")
if __name__ == "__main__":
    # Example run: classify the digits in a sample image and write the
    # result to predictions.txt in the current directory.
    image_url = "https://miro.medium.com/v2/resize:fit:720/format:webp/1*w7pBsjI3t3ZP-4Gdog-JdQ.png"
    model_name = "0xnu/mnist-ocr"
    output_path = "predictions.txt"
    main(image_url, model_name, output_path)
Copyright
(c) 2024 Finbarrs Oketunji. All Rights Reserved.
- Downloads last month
- 71
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support