Spaces:
Sleeping
Sleeping
File size: 1,687 Bytes
81a6137 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
# Hugging Face repository that supplies both the classifier and its
# matching feature extractor.
model_name = "immartian/improved_digits_recognition"
# Both objects are created once at import time and shared by every call to
# the functions below.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
def preprocess_image(image):
    """Turn a drawn image into the pixel tensor that is fed to the model.

    The input is wrapped in a PIL image, reduced to grayscale, resized to
    28x28 pixels and expanded back to three channels before being handed to
    the feature extractor.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (presumably the
            NumPy array produced by the sketchpad input -- TODO confirm
            against the Gradio version in use).

    Returns:
        The ``pixel_values`` entry of the feature extractor output, requested
        as PyTorch tensors via ``return_tensors="pt"``.
    """
    canvas = Image.fromarray(image)
    # Grayscale -> 28x28 -> RGB. The final conversion is there because,
    # per the original author, the model expects 3-channel images.
    digit = canvas.convert('L').resize((28, 28)).convert('RGB')
    encoded = feature_extractor(images=digit, return_tensors="pt")
    return encoded['pixel_values']
def predict_digit(image):
    """Classify a drawn digit and return a human-readable result string.

    Args:
        image: The drawing handed over by the UI, or ``None`` when there is
            nothing to classify (for example after the canvas was cleared
            while the interface is running in live mode).

    Returns:
        ``"Predicted Digit: <label>"`` for a drawing, or a short prompt
        asking the user to draw when ``image`` is ``None``.
    """
    # In live mode the handler also runs when the canvas is empty; without
    # this guard ``None`` would reach ``Image.fromarray`` and raise.
    if image is None:
        return "Draw a digit to get a prediction."

    pixel_values = preprocess_image(image)

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(pixel_values)

    # Index of the highest logit along the last axis is the predicted class.
    predicted_label = outputs.logits.argmax(-1).item()
    return f"Predicted Digit: {predicted_label}"
# Web UI: a sketchpad input wired to predict_digit, with the result shown as
# plain text.
demo = gr.Interface(
    title="MNIST Digit Recognition",
    description="Draw a digit (0-9) and let the model recognize it!",
    fn=predict_digit,
    inputs="sketchpad",
    outputs="text",
    live=True,  # re-run the prediction as the drawing changes
)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|