"""Gradio app that recognises a hand-drawn digit with a small PyTorch CNN."""

import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from PIL import Image

# Location of the pretrained weights (a PyTorch state dict).
model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"


class ImageClassifier(torch.nn.Module):
    """Convolutional classifier; must match the architecture used in training.

    Three 3x3 convolutions with ReLU activations, followed by global average
    pooling and a linear layer that produces 10 scores (one per digit).
    """

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, (3, 3)),
            torch.nn.ReLU(),
            # Collapse the spatial dimensions so the linear layer always
            # receives 64 features.
            torch.nn.AdaptiveAvgPool2d((1, 1)),
            torch.nn.Flatten(),
            torch.nn.Linear(64, 10),
        )

    def forward(self, x):
        """Return the raw class scores for the input batch ``x``."""
        return self.model(x)


# Instantiate the model and load the pretrained weights.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# machine without CUDA.
# NOTE(review): the checkpoint is downloaded and deserialised by torch; only
# point model_path at a source you trust.
model = ImageClassifier()
model.load_state_dict(
    torch.hub.load_state_dict_from_url(model_path, map_location="cpu")
)
model.eval()

# Preprocessing applied to every drawing: resize to 28x28, convert to a
# tensor and normalise with mean 0.5 / std 0.5. Built once at import time
# rather than on every prediction.
_PREPROCESS = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


def sketchToNumpy(image):
    """Extract the drawn image from the sketchpad value.

    Args:
        image: The value delivered by the sketchpad component; a mapping
            with a ``'composite'`` entry holding the drawn image, or ``None``
            when nothing was submitted.

    Returns:
        The ``'composite'`` entry, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None
    return image['composite']  # 'composite' contains the drawn image


def predict_digit(image):
    """Classify a drawn digit and return a human-readable label.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``, or ``None`` when
            there is no drawing to classify.

    Returns:
        ``"Predicted Label: <digit>"``, or a prompt asking the user to draw
        something when ``image`` is ``None``.
    """
    if image is None:
        # Nothing was drawn; avoid crashing inside Image.fromarray.
        return "Please draw a digit first."

    # Convert the sketchpad array into a grayscale PIL image.
    grayscale = Image.fromarray(image).convert('L')

    # Preprocess and add the batch dimension expected by the model.
    img_tensor = _PREPROCESS(grayscale).unsqueeze(0)

    # Inference only: gradients are not needed.
    with torch.no_grad():
        output = model(img_tensor)

    predicted_label = torch.argmax(output, dim=1).item()
    return f"Predicted Label: {predicted_label}"


# Create the Gradio interface.
interface = gr.Interface(
    fn=lambda x: predict_digit(sketchToNumpy(x)),
    inputs=gr.Sketchpad(crop_size=(256, 256), type='numpy', image_mode='L'),
    outputs="text",
    title="Digit Recognizer",
    description="Draw a digit (0-9) and the model will predict the number!",
)

# Launch the app.
if __name__ == "__main__":
    interface.launch()