File size: 1,687 Bytes
81a6137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

# Fetch the fine-tuned digit classifier together with the preprocessing
# configuration published in the same Hugging Face Hub repository, so the
# images fed to the model are prepared the way the repo specifies.
model_name = "immartian/improved_digits_recognition"
model, feature_extractor = (
    AutoModelForImageClassification.from_pretrained(model_name),
    AutoFeatureExtractor.from_pretrained(model_name),
)

def preprocess_image(image, size=(28, 28)):
    """Turn a sketchpad drawing into pixel values the classifier can consume.

    The drawing is converted to grayscale and resized to ``size``. It is then
    expanded back to three identical RGB channels and passed through the
    module-level ``feature_extractor``.

    Args:
        image: The sketchpad value. Accepted forms:
            * an array accepted by ``PIL.Image.fromarray`` (e.g. a numpy array);
            * a ``PIL.Image.Image``;
            * a dict holding the drawing under ``"composite"`` (the editor-style
              value newer Gradio sketchpads emit).
        size: ``(width, height)`` the drawing is resized to before feature
            extraction. Defaults to ``(28, 28)``, the size used originally.

    Returns:
        The ``'pixel_values'`` entry of ``feature_extractor``'s output, requested
        as PyTorch tensors (``return_tensors="pt"``).

    Raises:
        ValueError: If there is no drawing to process: ``image`` is ``None``,
            or a dict whose ``"composite"`` entry is missing or ``None``.
    """
    # Editor-style sketchpad values wrap the flattened drawing in a dict.
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        raise ValueError("No drawing to preprocess: the sketchpad is empty.")

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # NOTE(review): converting an RGBA canvas to 'L' drops the alpha channel,
    # so strokes on a transparent background may vanish. Also confirm the
    # model expects the stroke/background polarity the sketchpad produces.
    image = image.convert('L').resize(size)
    # Original note: the model expects 3-channel images. Replicate the
    # single gray channel into R, G and B.
    image = image.convert('RGB')
    inputs = feature_extractor(images=image, return_tensors="pt")
    return inputs['pixel_values']

def predict_digit(image):
    """Classify a hand-drawn digit and format the result for display.

    Args:
        image: The raw sketchpad value forwarded by Gradio. It is handed to
            ``preprocess_image`` unchanged. ``None`` means there is nothing
            to classify.

    Returns:
        ``"Predicted Digit: <index>"``, where ``<index>`` is the position of the
        highest logit. If ``image`` is ``None``, a short prompt asking the user
        to draw is returned instead.
    """
    # This interface runs live, so Gradio may call us with None when the
    # canvas is empty or cleared. Answer with a prompt instead of failing
    # inside the preprocessing step.
    if image is None:
        return "Draw a digit to get a prediction."

    inputs = preprocess_image(image)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(inputs)
        # NOTE(review): this is the raw class index. It equals the digit only
        # if the model's classes are ordered 0-9 (check model.config.id2label).
        predicted_label = outputs.logits.argmax(-1).item()

    return f"Predicted Digit: {predicted_label}"

# Build the web UI: the user sketches on a canvas and `predict_digit`'s text
# result is shown next to it. `live=True` re-runs the prediction as the
# drawing changes instead of waiting for a submit button.
demo = gr.Interface(
    predict_digit,
    inputs="sketchpad",
    outputs="text",
    live=True,
    title="MNIST Digit Recognition",
    description="Draw a digit (0-9) and let the model recognize it!",
)

# Start the server only when executed as a script, so importing this module
# does not launch the app.
if __name__ == "__main__":
    demo.launch()