Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
from transformers import AutoModelForImageClassification, AutoFeatureExtractor | |
# Hugging Face Hub checkpoint that supplies both the classifier weights and
# the matching preprocessing configuration.
model_name = "immartian/improved_digits_recognition"
# Preprocessing config first, then the image-classification model itself.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# Preprocessing function to transform the drawn image into a format the model can recognize
def preprocess_image(image):
    """Convert a sketchpad drawing into the ``pixel_values`` tensor the model consumes.

    The drawing is reduced to grayscale and shrunk to 28x28, then expanded
    back to 3 channels before being handed to ``feature_extractor``.

    Args:
        image: The value Gradio passes from the sketchpad. This is an image
            array accepted by ``PIL.Image.fromarray``, a ``PIL.Image.Image``,
            or an editor dict whose flattened drawing is stored under the
            ``"composite"`` key.

    Returns:
        The ``pixel_values`` entry of the ``feature_extractor`` output
        (requested with ``return_tensors="pt"``).

    Raises:
        ValueError: If there is no drawing to process (``image`` is ``None``,
            or the editor dict has no ``"composite"`` image).
    """
    # NOTE(review): newer Gradio Sketchpad components deliver an editor dict
    # rather than a bare array; unwrap it so both shapes are accepted.
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        # The interface runs live, so this can be reached with an empty or
        # cleared canvas; fail with a clear message instead of a PIL crash.
        raise ValueError("No drawing to preprocess.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Collapse to a single grayscale channel.
    # NOTE(review): for RGBA canvases convert("L") ignores the alpha channel,
    # so strokes on a transparent background may be lost — confirm the canvas
    # format produced by the deployed Gradio version.
    image = image.convert("L")
    image = image.resize((28, 28))  # shrink to 28x28 pixels
    image = image.convert("RGB")  # model expects 3-channel input
    inputs = feature_extractor(images=image, return_tensors="pt")
    return inputs["pixel_values"]
# Prediction function to classify the drawn digit
def predict_digit(image):
    """Classify a sketchpad drawing and return a human-readable result.

    Args:
        image: The raw sketchpad value from Gradio. This is an image array,
            or an editor dict holding the drawing under ``"composite"``.

    Returns:
        ``"Predicted Digit: <index>"``, where ``<index>`` is the arg-max class
        index of the model's logits. If the canvas is empty, returns a short
        prompt asking the user to draw.
    """
    # The interface runs live, so this fires on an empty or cleared canvas
    # too; answer with a prompt instead of letting preprocessing fail.
    if image is None or (isinstance(image, dict) and image.get("composite") is None):
        return "Draw a digit to get a prediction."
    pixel_values = preprocess_image(image)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    # NOTE(review): this reports the raw class index, which equals the digit
    # only if the checkpoint's classes are ordered 0-9 — check
    # model.config.id2label.
    predicted_label = outputs.logits.argmax(-1).item()
    return f"Predicted Digit: {predicted_label}"
# Gradio UI: a drawing canvas wired to predict_digit, with the result shown as text.
demo = gr.Interface(
    predict_digit,
    "sketchpad",  # canvas the user draws the digit on
    "text",
    live=True,  # re-run the prediction as the drawing changes
    title="MNIST Digit Recognition",
    description="Draw a digit (0-9) and let the model recognize it!",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()