# Hugging Face Space (status: Sleeping)
import gradio as gr | |
import torch | |
import numpy as np | |
from torchvision import transforms | |
from PIL import Image | |
# URL of the pretrained weights (a PyTorch state dict hosted on the Hugging Face Hub).
model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"
# Model architecture. It has to mirror the network the checkpoint was saved
# from, otherwise load_state_dict() will not find matching keys.
class ImageClassifier(torch.nn.Module):
    """Small convolutional classifier that emits 10 logits per image.

    Three unpadded 3x3 convolutions (1 -> 32 -> 64 -> 64 channels), each
    followed by ReLU, then global average pooling, flattening and a linear
    layer with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``model`` and the position of every layer in the
        # Sequential determine the state-dict keys ("model.0.weight", ...,
        # "model.8.bias"); do not rename or reorder them.
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, (3, 3)),
            torch.nn.ReLU(),
            # Global pooling makes the head independent of the spatial size.
            torch.nn.AdaptiveAvgPool2d((1, 1)),
            torch.nn.Flatten(),
            torch.nn.Linear(64, 10),
        )

    def forward(self, x):
        """Return class logits for a batch of single-channel images.

        Args:
            x: Batched tensor of shape (batch, 1, H, W). Each unpadded 3x3
                convolution trims 2 pixels per spatial dimension, so H and W
                must be at least 7.

        Returns:
            Tensor of shape (batch, 10) holding unnormalized scores.
        """
        return self.model(x)
# Instantiate the model and load the pretrained weights.
model = ImageClassifier()
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host; the model itself is never moved off the CPU in this file.
# NOTE(security): this unpickles a downloaded file, which can execute
# arbitrary code - only point model_path at a source you trust.
model.load_state_dict(
    torch.hub.load_state_dict_from_url(model_path, map_location="cpu")
)
# Inference mode: disables training-only behavior in layers that have any.
model.eval()
# Gradio preprocessing and prediction pipeline
def predict_digit(image_dict):
    """Classify a hand-drawn digit and return a display string.

    Args:
        image_dict: Value delivered by the Gradio input component. Accepted
            forms are a dict holding the drawing under ``"image"`` (the only
            form this function originally accepted) or under ``"composite"``
            (the ImageEditor/Sketchpad payload in Gradio 4+), or a bare NumPy
            array / PIL image.

    Returns:
        The string ``"Predicted Label: <d>"`` where ``<d>`` is the index of
        the highest-scoring class.

    Raises:
        ValueError: If no image can be extracted from ``image_dict`` (wrong
            type, missing key, or an empty canvas delivered as ``None``).
    """
    if isinstance(image_dict, dict):
        # Compare against None explicitly: truth-testing an ndarray raises.
        image = image_dict.get("image")
        if image is None:
            image = image_dict.get("composite")
    elif isinstance(image_dict, (np.ndarray, Image.Image)):
        image = image_dict
    else:
        image = None
    if image is None:
        raise ValueError("Invalid input format")

    # Normalize whatever arrived into a grayscale PIL image.
    # NOTE(review): no color inversion is applied here - confirm that the
    # stroke/background polarity of the sketch matches the training data.
    image = Image.fromarray(np.array(image)).convert('L')

    # Resize to 28x28, scale to [0, 1], then shift/scale to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    img_tensor = transform(image).unsqueeze(0)  # add the batch dimension

    # Forward pass without tracking gradients.
    with torch.no_grad():
        output = model(img_tensor)
    predicted_label = torch.argmax(output, dim=1).item()
    return f"Predicted Label: {predicted_label}"
# Build the Gradio UI: a sketchpad as input, the prediction text as output.
interface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(),  # Sketchpad for users to draw
    outputs="text",
    title="Digit Recognizer",
    description="Draw a digit (0-9) and the model will predict the number!",
)
# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()